diff --git a/.github/workflows/gradle-all.yml b/.github/workflows/gradle-all.yml new file mode 100644 index 000000000..abbd14106 --- /dev/null +++ b/.github/workflows/gradle-all.yml @@ -0,0 +1,152 @@ +name: Branches Java CI + +on: + # Trigger the workflow on push + # but only for the non master/1.0.x branches + push: + branches-ignore: + - 1.1.x + - master + +jobs: + build: + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew clean build -x test --no-daemon + + coretest: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew rsocket-core:test --no-daemon + + othertest: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew test -x :rsocket-core:test --no-daemon + + jcstress: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew jcstress --no-daemon + + publish: + needs: [ build, coretest, othertest, jcstress ] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Publish Packages to Artifactory + if: ${{ matrix.jdk == '1.8' }} + run: | + githubRef="${githubRef#refs/heads/}" + githubRef="${githubRef////-}" + ./gradlew -PversionSuffix="-${githubRef}-SNAPSHOT" -PbuildNumber="${buildNumber}" publishMavenPublicationToGitHubPackagesRepository --no-daemon --stacktrace + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + githubRef: ${{ github.ref }} + buildNumber: ${{ github.run_number }} \ No newline at end of file diff --git a/.github/workflows/gradle-main.yml b/.github/workflows/gradle-main.yml new file mode 100644 index 000000000..33bca8e72 --- /dev/null +++ b/.github/workflows/gradle-main.yml @@ -0,0 +1,161 @@ +name: Main Branches Java CI + +on: + # Trigger the workflow on push + # but only for the master/1.1.x branch + push: + branches: + - master + - 1.1.x + +jobs: + build: + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew clean build -x test --no-daemon + + coretest: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew rsocket-core:test --no-daemon + + othertest: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew test -x :rsocket-core:test --no-daemon + + jcstress: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew jcstress --no-daemon + + publish: + needs: [ build, coretest, othertest, jcstress ] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Publish Packages to Artifactory + if: ${{ matrix.jdk == '1.8' }} + run: ./gradlew -PversionSuffix="-SNAPSHOT" -PbuildNumber="${buildNumber}" publishMavenPublicationToSonatypeRepository --no-daemon --stacktrace + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + buildNumber: ${{ github.run_number }} + ORG_GRADLE_PROJECT_signingKey: ${{secrets.signingKey}} + ORG_GRADLE_PROJECT_signingPassword: ${{secrets.signingPassword}} + ORG_GRADLE_PROJECT_sonatypeUsername: ${{secrets.sonatypeUsername}} + ORG_GRADLE_PROJECT_sonatypePassword: ${{secrets.sonatypePassword}} + - name: Aggregate test reports with ciMate + if: always() + continue-on-error: true + env: + CIMATE_PROJECT_ID: m84qx17y + run: | + wget -q https://get.cimate.io/release/linux/cimate + chmod +x cimate + ./cimate "**/TEST-*.xml" \ No newline at end of file diff --git a/.github/workflows/gradle-pr.yml b/.github/workflows/gradle-pr.yml new file mode 100644 index 000000000..cecca085f --- /dev/null +++ b/.github/workflows/gradle-pr.yml @@ -0,0 +1,111 @@ +name: Pull Request Java CI + +on: [pull_request] + +jobs: + build: + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew clean build -x test --no-daemon + + coretest: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew rsocket-core:test --no-daemon + + othertest: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew test -x :rsocket-core:test --no-daemon + + jcstress: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew jcstress --no-daemon \ No newline at end of file diff --git a/.github/workflows/gradle-release.yml b/.github/workflows/gradle-release.yml new file mode 100644 index 000000000..922eb0e3e --- /dev/null +++ b/.github/workflows/gradle-release.yml @@ -0,0 +1,44 @@ +name: Release Java CI + +on: + # Trigger the workflow on push + push: + # Sequence of patterns matched against refs/tags + tags: + - '*' # Push events to matching *, i.e. 1.0, 20.15.10 + +jobs: + publish: + + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK 1.8 + uses: actions/setup-java@v1 + with: + java-version: 1.8 + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew clean build -x test + - name: Publish Packages to Sonotype + run: ./gradlew -Pversion="${githubRef#refs/tags/}" -PbuildNumber="${buildNumber}" sign publishMavenPublicationToSonatypeRepository + env: + githubRef: ${{ github.ref }} + buildNumber: ${{ github.run_number }} + ORG_GRADLE_PROJECT_signingKey: ${{secrets.signingKey}} + ORG_GRADLE_PROJECT_signingPassword: ${{secrets.signingPassword}} + ORG_GRADLE_PROJECT_sonatypeUsername: ${{secrets.sonatypeUsername}} + ORG_GRADLE_PROJECT_sonatypePassword: ${{secrets.sonatypePassword}} \ No newline at end of file diff --git a/.gitignore b/.gitignore index bde7e8f50..92865ccca 100644 --- a/.gitignore +++ b/.gitignore @@ -65,7 +65,7 @@ atlassian-ide-plugin.xml # NetBeans specific files/directories .nbattrs -/bin +**/bin/* #.gitignore in subdirectory .gitignore diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 89f3d4b42..000000000 --- a/.travis.yml +++ /dev/null @@ -1,30 +0,0 @@ -language: java -jdk: -- oraclejdk8 - -# force upgrade Java8 as per https://github.com/travis-ci/travis-ci/issues/4042 (fixes compilation issue) -addons: - apt: - packages: - - oracle-java8-installer - -dist: trusty -#group: edge -sudo: false -# as per http://blog.travis-ci.com/2014-12-17-faster-builds-with-container-based-infrastructure/ - -# script for build and release via Travis to Bintray -script: gradle/buildViaTravis.sh - -# cache between builds -cache: - directories: - - $HOME/.m2 - - $HOME/.gradle - -env: - global: - - secure: "WBCy0hsF96Xybj4n0AUrGY2m5FWCUa30XR+aVElSOO8d7v7BMypAT8mAd+yC2Y+j8WUGpIv59CqgeK1JrYdR9b3qRKhJmoE1Q92TotrxXMTIC9OKuU51LaaOqGYqx4SqiA2AyaikTFPd8um7KZfUpW/dG4IXySsiJ2OKT1jMUq6TmbWHnAYtjbl3u3WdjBQTIZNMtqG1+H1vIpsWyZrvbB4TWlNzhKBAu/YnlzMtvStrDaF7XrCJ2BQdMomQO18NH2gWxUEvLbQb6ip3wFl9CRe6vID7K1dmFwm08RPt9hRPC9yDahlIy8VvuNcWrP42TV+BVYy8V/hfaIo1pPsDBrtmVyc7YZjXSUM68orDFOkRB35qGkNIaAhy5Yt6G9QfwLXJkDFofW5KMKtDFUzf+j4DwS0CiDMF4k6Qq7YN1tYFXE9R8xa6Gv+wTNHqs4RURbYMS9IlbkhKxNbtyuema2sIUbsIfDezIzLI5BnfH2uli7O6/z0/G0Vfmf6A4q5Olm+7uhzMTI0GKheUIKr16SOxABlrwJtLJftzoKz9hYd3b7C9t61vYzccC3rWYobplwIcK2w50gFHQS8HLeiCjo8yjCx+IRSvAGaZIBPQdHCktrEYCVDUTXOxdaD6k6Ef+ppm8Nn+M+iC8x/G1wYE4x1lDqHw3GfhKsEQmiHL/98=" - - secure: "mbB+rv9eWUFQ9/yr2REH2ztH6r/Uq7cq/OJ5WK6yFp0TmPzlJ8jbEVwe/sdAMW2E4qrfMu1c2h3qsVm41pNx0MwEsIW/lTIZRiRmNYon32n+SHlRWyTn8dJeY/p1HoHs450OjLgB4X4jmRmfSt8IQ/w9ZCjF6HVcgR4ctt+myECTNcRidEIOahljnSJmnFFDsKbt2UJN96AfvvhbxcarEKgKLXLd9tQT2GlvEOM+hVOY9hKD5FvIoRp9heyCEAsSBXe+MIWQlh4jx+B4zCajZJ+8KN6M8KIt40lV8z4Zbc11jgq/xULJwkQIuVZvkJ3huIfUrxwLPgYWeai/TR/m3+2jy1hFajt96pnhJzFEz0IBL0wFALwAY1n2R/6uugEUYnDsFcGQGTsO5OeeOixiRPH5HNgfOhInqJoFh/887f+gq7OLXjlRCTsw+S9KknZ3iBpHX/+khurfAUC9khiMvufEq6Wyu0TvxhmGERFrs7uugeJ1VA85SDVQ6Au9MV831PeBGqzHpYG7w2kJj1EiFjBRUhCthxyDfX2b04egozlKF8JEifZ9EVj7pNMQUvVG2c9Wj6M0fG84NusnlZlA16XxAmfLevc9b/BOSSrqc2r9Z1ZvxFnBPP9H94Uqt9ZninhW/T49jRF+lQzD45MTVogzVk77XtdpzUemf4t5mHc=" - - secure: "GcPu3U4o2Dp7QLCqaAo3mGMJTl9yd+w+elXqqt7WDjrjm5p8mrzvQfyiJA7mRJVDTGpgib8fLctL1X1+QOX4fNKElrDUFhE3bWAqwVwHGPK4D3HCb6THD5XVqE4qcPmdLWPkvJ9ZY5nSIfuRVASjZTcc4XSXISK2jUSGar0PNYlo62/OFGvNvMz/qINU9RU7iYdDlL19yd72TKDfuK0UOKhQEGypamEHam3SMNCw/p8Q5K1vQe+Oba3ILCvYHJvqWc2NLjRXJjXfIaOq/NpCK6Lx2U9etdpkb5lyW5Cx1lkzIcRUq8ZUCwbkHog9LJoZGrZFh5AzlZ6kRuejBqu7AISmZy4s9HVAb7AQmNxvXkK9EIt8lavcaHnLYUIfuxvBqK/ptcUN5P/KXCs1DsbpADjB7YbUu/EQ2OAWncV31Z+O4uMHV29eGTtaz9LoK28+mHRfFHqoazWyuUejor6iSSkrCeqsLEvU8o6rH4oenKz7hLlZsJqHGACYtYNYi2CXYlTu0bMX+Hb1EtTu6Awm9Gn04TqVdmNexgF5CdqW4A696i6jlkPpVCt4B4nq4VPs2RMTkjVl3B7uOkDm18u35dncuhgsnMfVmo9cWX5COeyefdh6kdnKsUf0+IPbV/hix/OCP72dpuhxgcyzN+DvaVLzX7YOx7TpJTzPSKNEQZc=" - - secure: "UFJEzDEv6H2Qscg9UgZFVJq5oFvq7nQkVoSuGfh5Y4ZhL9PCK5f3Ft9oYEZOQwXaxWD1qivtJjQV3DdBiqsHkrnPrJ0hi3iYVDJo26xLNtu3welFw5Veqmgu2NuwjaDn6cjRFCJRLzpszMUWO1DvfLJTs3LuJDuXEyAKDw9eQgfOakqO4xeloyXgM7xnoXz11rgqtJNU6snjVPHftXNPTHGsNDlTR7SAIbjYwLMbdIKM2qjzrXkg+a94QOz2stnTDz9V5iYNH+3XXCcYxD9nb1Ol1XGWvtDnNGEhtGmylLdjHXwGLHiW2HOXskLzSkm7ASie1WdyHVHZb4X8LjxCy62S0FPevBgat1a443Khx5HCMYR/8dQrlOI82GYTr8n9U6QQE4Li8XLw64DVP9HGs9jdbsfEdlIsiPWqB6ujlwiO6pyfmQGQCgjALA+oD87uDQLcgh+SDYgE0ZwmwGzbjeynZpoCrEE8A1GHhSwkM9khx6EJFacm9XzqoUGK0wB1f8su+51fqPglF1zye80IFA4wOMMAY+KUc9du/vQ98f0lfjsNSOC02CKYxbA5RaakQMAYjirsZraA57xLmCSIGMhhW4wClQdJBww6LLz463yZU4WPwyqU+ZW12aV5dVLb5RWXIbZKmdT74DfZajHvqgTYpb05L5cJl7ApMspUkKk=" diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 000000000..ef7dd9dda --- /dev/null +++ b/AUTHORS @@ -0,0 +1,21 @@ +benjchristensen = Ben Christensen +gregwhitaker = Greg Whitaker +junaidkhalid = Junaid Khalid +kojilin = Kang-Sze Lin +krisskross = Kristoffer Sjogren +ktoso = Konrad Malawski +lehecka = Ondrej Lehecka +lexs = Alexander Blom +mostroverkhov = Maksym Ostroverkhov +nebhale = Ben Hale +NiteshKant = Nitesh Kant +qweek = Alex Novoselov +rdegnan = Ryland Degnan +robertroeser = Robert Roeser +rstoyanchev = Rossen Stoyanchev +simonbasle = Simon Baslé +somasun = somasun +stevegury = Steve Gury +tmontgomery = Todd L. Montgomery +yschimke = Yuri Schimke +OlegDokuka = Oleh Dokuka diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 47e7b87eb..56a5a7b69 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -6,22 +6,22 @@ When submitting code, please make every effort to follow existing conventions an ## License -By contributing your code, you agree to license your contribution under the terms of the APLv2: https://github.com/RSocket/reactivesocket-java/blob/master/LICENSE +By contributing your code, you agree to license your contribution under the terms of the APLv2: https://github.com/rsocket/rsocket-java/blob/1.0.x/LICENSE All files are released with the Apache 2.0 license. If you are adding a new file it should have a header like this: -``` -/** - * Copyright 2015 Netflix, Inc. - * +```java +/* + * Copyright 2015-2018 the original author or authors. + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/NOTICE b/NOTICE new file mode 100644 index 000000000..ea8e324f1 --- /dev/null +++ b/NOTICE @@ -0,0 +1,15 @@ +RSocket Java + +Copyright 2015-2018 the original author or authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/README.md b/README.md index 5f5d38adb..7ed3244b8 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # RSocket -[![Join the chat at https://gitter.im/RSocket/reactivesocket-java](https://badges.gitter.im/RSocket/reactivesocket-java.svg)](https://gitter.im/RSocket/reactivesocket-java?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +[![Join the chat at https://gitter.im/RSocket/RSocket-Java](https://badges.gitter.im/rsocket/rsocket-java.svg)](https://gitter.im/rsocket/rsocket-java) RSocket is a binary protocol for use on byte stream transports such as TCP, WebSockets, and Aeron. @@ -15,23 +15,40 @@ Learn more at http://rsocket.io ## Build and Binaries - +[![Build Status](https://github.com/rsocket/rsocket-java/actions/workflows/gradle-main.yml/badge.svg?branch=master)](https://github.com/rsocket/rsocket-java/actions/workflows/gradle-main.yml) -Snapshots are available via JFrog. +⚠️ The `master` branch is now dedicated to development of the `1.2.x` line. + +Releases and milestones are available via Maven Central. Example: ```groovy repositories { - maven { url 'https://oss.jfrog.org/libs-snapshot' } + mavenCentral() + maven { url 'https://repo.spring.io/milestone' } // Reactor milestones (if needed) +} +dependencies { + implementation 'io.rsocket:rsocket-core:1.2.0-SNAPSHOT' + implementation 'io.rsocket:rsocket-transport-netty:1.2.0-SNAPSHOT' } +``` +Snapshots are available via [oss.jfrog.org](oss.jfrog.org) (OJO). + +Example: + +```groovy +repositories { + maven { url 'https://maven.pkg.github.com/rsocket/rsocket-java' } + maven { url 'https://repo.spring.io/snapshot' } // Reactor snapshots (if needed) +} dependencies { - compile 'io.rsocket:reactivesocket:0.9-SNAPSHOT' + implementation 'io.rsocket:rsocket-core:1.2.0-SNAPSHOT' + implementation 'io.rsocket:rsocket-transport-netty:1.2.0-SNAPSHOT' } ``` -No releases to Maven Central or JCenter have occurred yet. ## Development @@ -54,14 +71,14 @@ Frames can be printed out to help debugging. Set the logger `io.rsocket.FrameLog ## Trivial Client -``` +```java package io.rsocket.transport.netty; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; import io.rsocket.transport.netty.client.WebsocketClientTransport; -import io.rsocket.util.PayloadImpl; +import io.rsocket.util.DefaultPayload; import reactor.core.publisher.Flux; import java.net.URI; @@ -69,32 +86,60 @@ import java.net.URI; public class ExampleClient { public static void main(String[] args) { WebsocketClientTransport ws = WebsocketClientTransport.create(URI.create("ws://rsocket-demo.herokuapp.com/ws")); - RSocket client = RSocketFactory.connect().keepAlive().transport(ws).start().block(); + RSocket clientRSocket = RSocketConnector.connectWith(ws).block(); try { - Flux s = client.requestStream(PayloadImpl.textPayload("peace")); + Flux s = clientRSocket.requestStream(DefaultPayload.create("peace")); s.take(10).doOnNext(p -> System.out.println(p.getDataUtf8())).blockLast(); } finally { - client.close().block(); + clientRSocket.dispose(); } } } ``` +## Zero Copy +By default to make RSocket easier to use it copies the incoming Payload. Copying the payload comes at cost to performance +and latency. If you want to use zero copy you must disable this. To disable copying you must include a `payloadDecoder` +argument in your `RSocketFactory`. This will let you manage the Payload without copying the data from the underlying +transport. You must free the Payload when you are done with them +or you will get a memory leak. Used correctly this will reduce latency and increase performance. + +### Example Server setup +```java +RSocketServer.create(new PingHandler()) + // Enable Zero Copy + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(TcpServerTransport.create(7878)) + .block() + .onClose() + .block(); +``` + +### Example Client setup +```java +RSocket clientRSocket = + RSocketConnector.create() + // Enable Zero Copy + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(TcpClientTransport.create(7878)) + .block(); +``` + ## Bugs and Feedback For bugs, questions and discussions please use the [Github Issues](https://github.com/RSocket/reactivesocket-java/issues). ## LICENSE -Copyright 2015 Netflix, Inc. +Copyright 2015-2020 the original author or authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - +http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 000000000..656e2de4b --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,47 @@ +## Usage of JMH tasks + +Only execute specific benchmark(s) (wildcards are added before and after): +``` +../gradlew jmh --include="(BenchmarkPrimary|OtherBench)" +``` +If you want to specify the wildcards yourself, you can pass the full regexp: +``` +../gradlew jmh --fullInclude=.*MyBenchmark.* +``` + +Specify extra profilers: +``` +../gradlew jmh --profilers="gc,stack" +``` + +Prominent profilers (for full list call `jmhProfilers` task): +- comp - JitCompilations, tune your iterations +- stack - which methods used most time +- gc - print garbage collection defaultWeightedStats +- hs_thr - thread usage + +Change report format from JSON to one of [CSV, JSON, NONE, SCSV, TEXT]: +``` +./gradlew jmh --format=csv +``` + +Specify JVM arguments: +``` +../gradlew jmh --jvmArgs="-Dtest.cluster=local" +``` + +Run in verification mode (execute benchmarks with minimum of fork/warmup-/benchmark-iterations): +``` +../gradlew jmh --verify=true +``` + +## Comparing with the baseline +If you wish you run two sets of benchmarks, one for the current change and another one for the "baseline", +there is an additional task `jmhBaseline` that will use the latest release: +``` +../gradlew jmh jmhBaseline --include=MyBenchmark +``` + +## Resources +- http://tutorials.jenkov.com/java-performance/jmh.html (Introduction) +- http://hg.openjdk.java.net/code-tools/jmh/file/tip/jmh-samples/src/main/java/org/openjdk/jmh/samples/ (Samples) diff --git a/benchmarks/build.gradle b/benchmarks/build.gradle new file mode 100644 index 000000000..74e571d1f --- /dev/null +++ b/benchmarks/build.gradle @@ -0,0 +1,170 @@ +apply plugin: 'java' +apply plugin: 'idea' + +configurations { + current + baseline { + resolutionStrategy.cacheChangingModulesFor 0, 'seconds' + } +} + +dependencies { + // Use the baseline to avoid using new APIs in the benchmarks + compileOnly "io.rsocket:rsocket-core:${perfBaselineVersion}" + compileOnly "io.rsocket:rsocket-transport-local:${perfBaselineVersion}" + compileOnly "io.rsocket:rsocket-transport-netty:${perfBaselineVersion}" + + implementation "org.openjdk.jmh:jmh-core:1.35" + annotationProcessor "org.openjdk.jmh:jmh-generator-annprocess:1.35" + + current project(':rsocket-core') + current project(':rsocket-transport-local') + current project(':rsocket-transport-netty') + baseline "io.rsocket:rsocket-core:${perfBaselineVersion}", { + changing = true + } + baseline "io.rsocket:rsocket-transport-local:${perfBaselineVersion}", { + changing = true + } +} + +task jmhProfilers(type: JavaExec, description:'Lists the available profilers for the jmh task', group: 'Development') { + classpath = sourceSets.main.runtimeClasspath + main = 'org.openjdk.jmh.Main' + args '-lprof' +} + +task jmh(type: JmhExecTask, description: 'Executing JMH benchmarks') { + main = 'org.openjdk.jmh.Main' + classpath = sourceSets.main.runtimeClasspath + configurations.current +} + +task jmhBaseline(type: JmhExecTask, description: 'Executing JMH baseline benchmarks') { + main = 'org.openjdk.jmh.Main' + classpath = sourceSets.main.runtimeClasspath + configurations.baseline +} + +clean { + delete "${projectDir}/src/main/generated" +} + +class JmhExecTask extends JavaExec { + + private String include; + private String fullInclude; + private String exclude; + private String format = "json"; + private String profilers; + private String jmhJvmArgs; + private String verify; + + public JmhExecTask() { + super(); + } + + public String getInclude() { + return include; + } + + @Option(option = "include", description="configure bench inclusion using substring") + public void setInclude(String include) { + this.include = include; + } + + public String getFullInclude() { + return fullInclude; + } + + @Option(option = "fullInclude", description = "explicitly configure bench inclusion using full JMH style regexp") + public void setFullInclude(String fullInclude) { + this.fullInclude = fullInclude; + } + + public String getExclude() { + return exclude; + } + + @Option(option = "exclude", description = "explicitly configure bench exclusion using full JMH style regexp") + public void setExclude(String exclude) { + this.exclude = exclude; + } + + public String getFormat() { + return format; + } + + @Option(option = "format", description = "configure report format") + public void setFormat(String format) { + this.format = format; + } + + public String getProfilers() { + return profilers; + } + + @Option(option = "profilers", description = "configure jmh profiler(s) to use, comma separated") + public void setProfilers(String profilers) { + this.profilers = profilers; + } + + public String getJmhJvmArgs() { + return jmhJvmArgs; + } + + @Option(option = "jvmArgs", description = "configure additional JMH JVM arguments, comma separated") + public void setJmhJvmArgs(String jvmArgs) { + this.jmhJvmArgs = jvmArgs; + } + + public String getVerify() { + return verify; + } + + @Option(option = "verify", description = "run in verify mode") + public void setVerify(String verify) { + this.verify = verify; + } + + @TaskAction + public void exec() { + File resultFile = getProject().file("build/reports/" + getName() + "/result." + format); + + if (include != null) { + args(".*" + include + ".*"); + } + else if (fullInclude != null) { + args(fullInclude); + } + + if(exclude != null) { + args("-e", exclude); + } + if(verify != null) { // execute benchmarks with the minimum amount of execution (only to check if they are working) + System.out.println("Running in verify mode"); + args("-f", 1); + args("-wi", 1); + args("-i", 1); + } + args("-foe", "true"); //fail-on-error + args("-v", "NORMAL"); //verbosity [SILENT, NORMAL, EXTRA] + if(profilers != null) { + for (String prof : profilers.split(",")) { + args("-prof", prof); + } + } + args("-jvmArgsPrepend", "-Xmx3072m"); + args("-jvmArgsPrepend", "-Xms3072m"); + if(jmhJvmArgs != null) { + for(String jvmArg : jmhJvmArgs.split(" ")) { + args("-jvmArgsPrepend", jvmArg); + } + } + args("-rf", format); + args("-rff", resultFile); + + System.out.println("\nExecuting JMH with: " + getArgs() + "\n"); + resultFile.getParentFile().mkdirs(); + + super.exec(); + } +} diff --git a/benchmarks/src/main/java/io/rsocket/MaxPerfSubscriber.java b/benchmarks/src/main/java/io/rsocket/MaxPerfSubscriber.java new file mode 100644 index 000000000..2e6fa6acc --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/MaxPerfSubscriber.java @@ -0,0 +1,37 @@ +package io.rsocket; + +import java.util.concurrent.CountDownLatch; +import org.openjdk.jmh.infra.Blackhole; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; + +public class MaxPerfSubscriber extends CountDownLatch implements CoreSubscriber { + + final Blackhole blackhole; + + public MaxPerfSubscriber(Blackhole blackhole) { + super(1); + this.blackhole = blackhole; + } + + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(T payload) { + blackhole.consume(payload); + } + + @Override + public void onError(Throwable t) { + blackhole.consume(t); + countDown(); + } + + @Override + public void onComplete() { + countDown(); + } +} diff --git a/benchmarks/src/main/java/io/rsocket/PayloadsMaxPerfSubscriber.java b/benchmarks/src/main/java/io/rsocket/PayloadsMaxPerfSubscriber.java new file mode 100644 index 000000000..7a7a1fdd6 --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/PayloadsMaxPerfSubscriber.java @@ -0,0 +1,16 @@ +package io.rsocket; + +import org.openjdk.jmh.infra.Blackhole; + +public class PayloadsMaxPerfSubscriber extends MaxPerfSubscriber { + + public PayloadsMaxPerfSubscriber(Blackhole blackhole) { + super(blackhole); + } + + @Override + public void onNext(Payload payload) { + payload.release(); + super.onNext(payload); + } +} diff --git a/benchmarks/src/main/java/io/rsocket/PayloadsPerfSubscriber.java b/benchmarks/src/main/java/io/rsocket/PayloadsPerfSubscriber.java new file mode 100644 index 000000000..efc116958 --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/PayloadsPerfSubscriber.java @@ -0,0 +1,16 @@ +package io.rsocket; + +import org.openjdk.jmh.infra.Blackhole; + +public class PayloadsPerfSubscriber extends PerfSubscriber { + + public PayloadsPerfSubscriber(Blackhole blackhole) { + super(blackhole); + } + + @Override + public void onNext(Payload payload) { + payload.release(); + super.onNext(payload); + } +} diff --git a/benchmarks/src/main/java/io/rsocket/PerfSubscriber.java b/benchmarks/src/main/java/io/rsocket/PerfSubscriber.java new file mode 100644 index 000000000..92577d95c --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/PerfSubscriber.java @@ -0,0 +1,41 @@ +package io.rsocket; + +import java.util.concurrent.CountDownLatch; +import org.openjdk.jmh.infra.Blackhole; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; + +public class PerfSubscriber extends CountDownLatch implements CoreSubscriber { + + final Blackhole blackhole; + + Subscription s; + + public PerfSubscriber(Blackhole blackhole) { + super(1); + this.blackhole = blackhole; + } + + @Override + public void onSubscribe(Subscription s) { + this.s = s; + s.request(1); + } + + @Override + public void onNext(T payload) { + blackhole.consume(payload); + s.request(1); + } + + @Override + public void onError(Throwable t) { + blackhole.consume(t); + countDown(); + } + + @Override + public void onComplete() { + countDown(); + } +} diff --git a/benchmarks/src/main/java/io/rsocket/core/RSocketPerf.java b/benchmarks/src/main/java/io/rsocket/core/RSocketPerf.java new file mode 100644 index 000000000..4437400c4 --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/core/RSocketPerf.java @@ -0,0 +1,226 @@ +package io.rsocket.core; + +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.PayloadsMaxPerfSubscriber; +import io.rsocket.PayloadsPerfSubscriber; +import io.rsocket.RSocket; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.ClientTransport; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.local.LocalClientTransport; +import io.rsocket.transport.local.LocalServerTransport; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.EmptyPayload; +import java.lang.reflect.Field; +import java.util.Queue; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.LockSupport; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; +import org.reactivestreams.Publisher; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +@BenchmarkMode({Mode.Throughput, Mode.SampleTime}) +@Fork(value = 2) +@Warmup(iterations = 10) +@Measurement(iterations = 10, time = 10) +@State(Scope.Benchmark) +@OutputTimeUnit(TimeUnit.MICROSECONDS) +public class RSocketPerf { + + @Param({"tcp", "websocket", "local"}) + String transportType; + + @Param({"0", "64", "1024", "131072", "1048576", "15728640"}) + String payloadSize; + + Payload payload; + Mono payloadMono; + Flux payloadsFlux; + + RSocket client; + Closeable server; + Queue clientsQueue; + + @TearDown + public void tearDown() { + client.dispose(); + server.dispose(); + payload.release(); + } + + @TearDown(Level.Iteration) + public void awaitToBeConsumed() { + while (!clientsQueue.isEmpty()) { + LockSupport.parkNanos(1000); + } + } + + @Setup + public void setUp() throws NoSuchFieldException, IllegalAccessException, ClassNotFoundException { + ClientTransport clientTransport; + ServerTransport serverTransport; + switch (transportType) { + case "tcp": + clientTransport = TcpClientTransport.create(8081); + serverTransport = TcpServerTransport.create(8081); + break; + case "websocket": + clientTransport = WebsocketClientTransport.create(8081); + serverTransport = WebsocketServerTransport.create(8081); + break; + case "local": + default: + clientTransport = LocalClientTransport.create("server"); + serverTransport = LocalServerTransport.create("server"); + break; + } + Payload payload; + int payloadSize = Integer.parseInt(this.payloadSize); + if (payloadSize == 0) { + payload = EmptyPayload.INSTANCE; + } else { + byte[] randomMetadata = new byte[payloadSize / 2]; + byte[] randomData = new byte[payloadSize / 2]; + ThreadLocalRandom.current().nextBytes(randomData); + ThreadLocalRandom.current().nextBytes(randomMetadata); + + payload = ByteBufPayload.create(randomData, randomMetadata); + } + + this.payload = payload; + this.payloadMono = Mono.fromSupplier(payload::retain); + this.payloadsFlux = Flux.range(0, 100000).map(__ -> payload.retain()); + this.server = + RSocketServer.create( + (setup, sendingSocket) -> + Mono.just( + new RSocket() { + + @Override + public Mono fireAndForget(Payload payload) { + payload.release(); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + payload.release(); + return payloadMono; + } + + @Override + public Flux requestStream(Payload payload) { + payload.release(); + return payloadsFlux; + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads); + } + })) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(serverTransport) + .block(); + + this.client = + RSocketConnector.create() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(clientTransport) + .block(); + + try { + Field sendProcessorField = RSocketRequester.class.getDeclaredField("sendProcessor"); + sendProcessorField.setAccessible(true); + + clientsQueue = (Queue) sendProcessorField.get(client); + } catch (Throwable t) { + Field sendProcessorField = + Class.forName("io.rsocket.core.RequesterResponderSupport") + .getDeclaredField("sendProcessor"); + sendProcessorField.setAccessible(true); + + clientsQueue = (Queue) sendProcessorField.get(client); + } + } + + @Benchmark + @SuppressWarnings("unchecked") + public PayloadsPerfSubscriber fireAndForget(Blackhole blackhole) throws InterruptedException { + PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); + client.fireAndForget(payload.retain()).subscribe((CoreSubscriber) subscriber); + subscriber.await(); + + return subscriber; + } + + @Benchmark + public PayloadsPerfSubscriber requestResponse(Blackhole blackhole) throws InterruptedException { + PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); + client.requestResponse(payload.retain()).subscribe(subscriber); + subscriber.await(); + + return subscriber; + } + + @Benchmark + public PayloadsPerfSubscriber requestStreamWithRequestByOneStrategy(Blackhole blackhole) + throws InterruptedException { + PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); + client.requestStream(payload.retain()).subscribe(subscriber); + subscriber.await(); + + return subscriber; + } + + @Benchmark + public PayloadsMaxPerfSubscriber requestStreamWithRequestAllStrategy(Blackhole blackhole) + throws InterruptedException { + PayloadsMaxPerfSubscriber subscriber = new PayloadsMaxPerfSubscriber(blackhole); + client.requestStream(payload.retain()).subscribe(subscriber); + subscriber.await(); + + return subscriber; + } + + @Benchmark + public PayloadsPerfSubscriber requestChannelWithRequestByOneStrategy(Blackhole blackhole) + throws InterruptedException { + PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); + client.requestChannel(payloadsFlux).subscribe(subscriber); + subscriber.await(); + + return subscriber; + } + + @Benchmark + public PayloadsMaxPerfSubscriber requestChannelWithRequestAllStrategy(Blackhole blackhole) + throws InterruptedException { + PayloadsMaxPerfSubscriber subscriber = new PayloadsMaxPerfSubscriber(blackhole); + client.requestChannel(payloadsFlux).subscribe(subscriber); + subscriber.await(); + + return subscriber; + } +} diff --git a/benchmarks/src/main/java/io/rsocket/frame/FrameHeaderCodecPerf.java b/benchmarks/src/main/java/io/rsocket/frame/FrameHeaderCodecPerf.java new file mode 100644 index 000000000..402cdb353 --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/frame/FrameHeaderCodecPerf.java @@ -0,0 +1,55 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; + +@BenchmarkMode(Mode.Throughput) +@Fork( + value = 1 // , jvmArgsAppend = {"-Dio.netty.leakDetection.level=advanced"} + ) +@Warmup(iterations = 10) +@Measurement(iterations = 10) +@State(Scope.Thread) +public class FrameHeaderCodecPerf { + + @Benchmark + public void encode(Input input) { + ByteBuf byteBuf = FrameHeaderCodec.encodeStreamZero(input.allocator, FrameType.SETUP, 0); + boolean release = byteBuf.release(); + input.bh.consume(release); + } + + @Benchmark + public void decode(Input input) { + ByteBuf frame = input.frame; + FrameType frameType = FrameHeaderCodec.frameType(frame); + int streamId = FrameHeaderCodec.streamId(frame); + int flags = FrameHeaderCodec.flags(frame); + input.bh.consume(streamId); + input.bh.consume(flags); + input.bh.consume(frameType); + } + + @State(Scope.Benchmark) + public static class Input { + Blackhole bh; + FrameType frameType; + ByteBufAllocator allocator; + ByteBuf frame; + + @Setup + public void setup(Blackhole bh) { + this.bh = bh; + this.frameType = FrameType.REQUEST_RESPONSE; + allocator = ByteBufAllocator.DEFAULT; + frame = FrameHeaderCodec.encode(allocator, 123, FrameType.SETUP, 0); + } + + @TearDown + public void teardown() { + frame.release(); + } + } +} diff --git a/benchmarks/src/main/java/io/rsocket/frame/FrameTypePerf.java b/benchmarks/src/main/java/io/rsocket/frame/FrameTypePerf.java new file mode 100644 index 000000000..efa22104f --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/frame/FrameTypePerf.java @@ -0,0 +1,38 @@ +package io.rsocket.frame; + +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; + +@BenchmarkMode(Mode.Throughput) +@Fork( + value = 1 // , jvmArgsAppend = {"-Dio.netty.leakDetection.level=advanced"} + ) +@Warmup(iterations = 10) +@Measurement(iterations = 10) +@State(Scope.Thread) +public class FrameTypePerf { + @Benchmark + public void lookup(Input input) { + FrameType frameType = input.frameType; + boolean b = + frameType.canHaveData() + && frameType.canHaveMetadata() + && frameType.isFragmentable() + && frameType.isRequestType() + && frameType.hasInitialRequestN(); + + input.bh.consume(b); + } + + @State(Scope.Benchmark) + public static class Input { + Blackhole bh; + FrameType frameType; + + @Setup + public void setup(Blackhole bh) { + this.bh = bh; + this.frameType = FrameType.REQUEST_RESPONSE; + } + } +} diff --git a/benchmarks/src/main/java/io/rsocket/frame/PayloadFrameCodecPerf.java b/benchmarks/src/main/java/io/rsocket/frame/PayloadFrameCodecPerf.java new file mode 100644 index 000000000..ead1c2fa3 --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/frame/PayloadFrameCodecPerf.java @@ -0,0 +1,77 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; + +@BenchmarkMode(Mode.Throughput) +@Fork( + value = 1 // , jvmArgsAppend = {"-Dio.netty.leakDetection.level=advanced"} + ) +@Warmup(iterations = 10) +@Measurement(iterations = 10) +@State(Scope.Thread) +public class PayloadFrameCodecPerf { + + @Benchmark + public void encode(Input input) { + ByteBuf encode = + PayloadFrameCodec.encode( + input.allocator, + 100, + false, + true, + false, + Unpooled.wrappedBuffer(input.metadata), + Unpooled.wrappedBuffer(input.data)); + boolean release = encode.release(); + input.bh.consume(release); + } + + @Benchmark + public void decode(Input input) { + ByteBuf frame = input.payload; + ByteBuf data = PayloadFrameCodec.data(frame); + ByteBuf metadata = PayloadFrameCodec.metadata(frame); + input.bh.consume(data); + input.bh.consume(metadata); + } + + @State(Scope.Benchmark) + public static class Input { + Blackhole bh; + FrameType frameType; + ByteBufAllocator allocator; + ByteBuf payload; + byte[] metadata = new byte[512]; + byte[] data = new byte[4096]; + + @Setup + public void setup(Blackhole bh) { + this.bh = bh; + this.frameType = FrameType.REQUEST_RESPONSE; + allocator = ByteBufAllocator.DEFAULT; + + // Encode a payload and then copy it a single bytebuf + payload = allocator.buffer(); + ByteBuf encode = + PayloadFrameCodec.encode( + allocator, + 100, + false, + true, + false, + Unpooled.wrappedBuffer(metadata), + Unpooled.wrappedBuffer(data)); + payload.writeBytes(encode); + encode.release(); + } + + @TearDown + public void teardown() { + payload.release(); + } + } +} diff --git a/benchmarks/src/main/java/io/rsocket/metadata/WellKnownMimeTypePerf.java b/benchmarks/src/main/java/io/rsocket/metadata/WellKnownMimeTypePerf.java new file mode 100644 index 000000000..8f429fc19 --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/metadata/WellKnownMimeTypePerf.java @@ -0,0 +1,96 @@ +package io.rsocket.metadata; + +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; + +@BenchmarkMode(Mode.Throughput) +@Fork(value = 1) +@Warmup(iterations = 10) +@Measurement(iterations = 10) +@State(Scope.Thread) +public class WellKnownMimeTypePerf { + + // this is the old values() looping implementation of fromIdentifier + private WellKnownMimeType fromIdValuesLoop(int id) { + if (id < 0 || id > 127) { + return WellKnownMimeType.UNPARSEABLE_MIME_TYPE; + } + for (WellKnownMimeType value : WellKnownMimeType.values()) { + if (value.getIdentifier() == id) { + return value; + } + } + return WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE; + } + + // this is the core of the old values() looping implementation of fromString + private WellKnownMimeType fromStringValuesLoop(String mimeType) { + for (WellKnownMimeType value : WellKnownMimeType.values()) { + if (mimeType.equals(value.getString())) { + return value; + } + } + return WellKnownMimeType.UNPARSEABLE_MIME_TYPE; + } + + @Benchmark + public void fromIdArrayLookup(final Blackhole bh) { + // negative lookup + bh.consume(WellKnownMimeType.fromIdentifier(-10)); + bh.consume(WellKnownMimeType.fromIdentifier(-1)); + // too large lookup + bh.consume(WellKnownMimeType.fromIdentifier(129)); + // first lookup + bh.consume(WellKnownMimeType.fromIdentifier(0)); + // middle lookup + bh.consume(WellKnownMimeType.fromIdentifier(37)); + // reserved lookup + bh.consume(WellKnownMimeType.fromIdentifier(63)); + // last lookup + bh.consume(WellKnownMimeType.fromIdentifier(127)); + } + + @Benchmark + public void fromIdValuesLoopLookup(final Blackhole bh) { + // negative lookup + bh.consume(fromIdValuesLoop(-10)); + bh.consume(fromIdValuesLoop(-1)); + // too large lookup + bh.consume(fromIdValuesLoop(129)); + // first lookup + bh.consume(fromIdValuesLoop(0)); + // middle lookup + bh.consume(fromIdValuesLoop(37)); + // reserved lookup + bh.consume(fromIdValuesLoop(63)); + // last lookup + bh.consume(fromIdValuesLoop(127)); + } + + @Benchmark + public void fromStringMapLookup(final Blackhole bh) { + // unknown lookup + bh.consume(WellKnownMimeType.fromString("foo/bar")); + // first lookup + bh.consume(WellKnownMimeType.fromString(WellKnownMimeType.APPLICATION_AVRO.getString())); + // middle lookup + bh.consume(WellKnownMimeType.fromString(WellKnownMimeType.VIDEO_VP8.getString())); + // last lookup + bh.consume( + WellKnownMimeType.fromString( + WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString())); + } + + @Benchmark + public void fromStringValuesLoopLookup(final Blackhole bh) { + // unknown lookup + bh.consume(fromStringValuesLoop("foo/bar")); + // first lookup + bh.consume(fromStringValuesLoop(WellKnownMimeType.APPLICATION_AVRO.getString())); + // middle lookup + bh.consume(fromStringValuesLoop(WellKnownMimeType.VIDEO_VP8.getString())); + // last lookup + bh.consume( + fromStringValuesLoop(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString())); + } +} diff --git a/build.gradle b/build.gradle index 9ee18cea3..2971a7767 100644 --- a/build.gradle +++ b/build.gradle @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -15,189 +15,276 @@ */ plugins { - id 'com.gradle.build-scan' version '1.9' // declare before any other plugin - - id 'com.github.sherter.google-java-format' version '0.6' - id 'com.github.johnrengelman.shadow' version '2.0.1' apply false - id 'me.champeau.gradle.jmh' version '0.4.4' apply false - id 'io.morethan.jmhreport' version '0.6.2.1' apply false - - id 'com.jfrog.artifactory' version '4.5.2' - id 'com.jfrog.bintray' version '1.7.3' + id 'com.github.sherter.google-java-format' version '0.9' apply false + id 'me.champeau.jmh' version '0.7.1' apply false + id 'io.spring.dependency-management' version '1.1.0' apply false + id 'io.morethan.jmhreport' version '0.9.0' apply false + id 'io.github.reyerizo.gradle.jcstress' version '0.8.15' apply false + id 'com.github.vlsi.gradle-extensions' version '1.89' apply false } -repositories { - jcenter() +boolean isCiServer = ["CI", "CONTINUOUS_INTEGRATION", "TRAVIS", "CIRCLECI", "bamboo_planKey", "GITHUB_ACTION"].with { + retainAll(System.getenv().keySet()) + return !isEmpty() } -description = 'RSocket: stream oriented messaging passing with Reactive Stream semantics.' +subprojects { + apply plugin: 'io.spring.dependency-management' + apply plugin: 'com.github.sherter.google-java-format' + apply plugin: 'com.github.vlsi.gradle-extensions' -buildScan { licenseAgreementUrl = 'https://gradle.com/terms-of-service'; licenseAgree = 'yes' } + ext['reactor-bom.version'] = '2022.0.7-SNAPSHOT' + ext['logback.version'] = '1.2.13' + ext['netty-bom.version'] = '4.1.117.Final' + ext['netty-boringssl.version'] = '2.0.69.Final' + ext['hdrhistogram.version'] = '2.1.12' + ext['mockito.version'] = '4.11.0' + ext['slf4j.version'] = '1.7.36' + ext['jmh.version'] = '1.36' + ext['junit.version'] = '5.9.3' + ext['micrometer.version'] = '1.11.12' + ext['micrometer-tracing.version'] = '1.1.13' + ext['assertj.version'] = '3.24.2' + ext['netflix.limits.version'] = '0.3.6' + ext['bouncycastle-bcpkix.version'] = '1.70' + ext['awaitility.version'] = '4.2.0' -googleJavaFormat { - toolVersion = '1.4' -} + group = "io.rsocket" -subprojects { - apply plugin: 'java' - apply plugin: 'maven' - apply plugin: 'maven-publish' - apply plugin: 'com.jfrog.bintray' - apply plugin: 'com.jfrog.artifactory' - - group = 'io.rsocket' - version = mavenversion - - compileJava { - sourceCompatibility = 1.8 - targetCompatibility = 1.8 - options.compilerArgs << '-Xlint:all,-overloads,-rawtypes,-unchecked' + googleJavaFormat { + toolVersion = '1.6' } ext { - // common - jsr305Version = '3.0.2' - reactorVersion = '3.1.0.RELEASE' - nettyVersion = '4.1.15.Final' - reactiveStreamsVersion = '1.0.1' - slf4jVersion = '1.7.25' - // aeron - aeronVersion = '1.4.1' - // netty - reactorNettyVersion = '0.7.0.M2' - // spectator - spectatorVersion = '0.57.1' - hdrHistogramVersion = '2.1.9' - // tck-drivers - jacksonVersion = '2.9.1' - commonsLang3Version = '3.6' - airlineVersion = '0.8' - rxjavaVersion = '2.1.3' - // test - junitVersion = '4.12' - hamcrestVersion = '1.3' - mockitoVersion = '2.10.0' - jmhVersion = '1.19' - } - - // custom tasks for creating source/javadoc jars - task sourcesJar(type: Jar, dependsOn: classes) { - classifier = 'sources' - from sourceSets.main.allSource + if (project.hasProperty('versionSuffix')) { + project.version += project.getProperty('versionSuffix') + } } - task javadocJar(type: Jar, dependsOn: javadoc) { - classifier = 'javadoc' - from javadoc.destinationDir + configurations.all { + resolutionStrategy.cacheChangingModulesFor 60, "minutes" } - tasks.bintrayUpload.dependsOn tasks.jar, tasks.sourcesJar, tasks.javadocJar + dependencyManagement { + imports { + mavenBom "io.projectreactor:reactor-bom:${ext['reactor-bom.version']}" + mavenBom "io.netty:netty-bom:${ext['netty-bom.version']}" + mavenBom "org.junit:junit-bom:${ext['junit.version']}" + mavenBom "io.micrometer:micrometer-bom:${ext['micrometer.version']}" + mavenBom "io.micrometer:micrometer-tracing-bom:${ext['micrometer-tracing.version']}" + } - // add javadoc/source jar tasks as artifacts - artifacts { - archives sourcesJar, javadocJar, jar + dependencies { + dependency "com.netflix.concurrency-limits:concurrency-limits-core:${ext['netflix.limits.version']}" + dependency "ch.qos.logback:logback-classic:${ext['logback.version']}" + dependency "io.netty:netty-tcnative-boringssl-static:${ext['netty-boringssl.version']}" + dependency "org.bouncycastle:bcpkix-jdk15on:${ext['bouncycastle-bcpkix.version']}" + dependency "org.assertj:assertj-core:${ext['assertj.version']}" + dependency "org.hdrhistogram:HdrHistogram:${ext['hdrhistogram.version']}" + dependency "org.slf4j:slf4j-api:${ext['slf4j.version']}" + dependency "org.awaitility:awaitility:${ext['awaitility.version']}" + dependencySet(group: 'org.mockito', version: ext['mockito.version']) { + entry 'mockito-junit-jupiter' + entry 'mockito-core' + } + dependencySet(group: 'org.openjdk.jmh', version: ext['jmh.version']) { + entry 'jmh-core' + entry 'jmh-generator-annprocess' + } + } + generatedPomCustomization { + enabled = false + } } repositories { - maven { url 'http://repo.spring.io/milestone' } - maven { url 'https://oss.jfrog.org/libs-snapshot' } - maven { url 'https://dl.bintray.com/rsocket/RSocket' } - maven { url 'https://dl.bintray.com/reactivesocket/ReactiveSocket' } + mavenCentral() + + maven { + url 'https://repo.spring.io/milestone' + content { + includeGroup "io.micrometer" + includeGroup "io.projectreactor" + includeGroup "io.projectreactor.netty" + includeGroup "io.micrometer" + } + } + + maven { + url 'https://repo.spring.io/snapshot' + content { + includeGroup "io.micrometer" + includeGroup "io.projectreactor" + includeGroup "io.projectreactor.netty" + } + } + + if (version.endsWith('SNAPSHOT') || project.hasProperty('versionSuffix')) { + maven { url 'https://repo.spring.io/libs-snapshot' } + maven { url 'https://oss.jfrog.org/artifactory/oss-snapshot-local' } + mavenLocal() + } } - dependencies { - compile "io.projectreactor:reactor-core:3.1.0.RELEASE" - compile "io.netty:netty-buffer:4.1.15.Final" - compile "org.reactivestreams:reactive-streams:1.0.1" - compile "org.slf4j:slf4j-api:1.7.25" - compile "com.google.code.findbugs:jsr305:3.0.2" - - testCompile "junit:junit:4.12" - testCompile "org.mockito:mockito-core:2.10.0" - testCompile "org.hamcrest:hamcrest-library:1.3" - testCompile "org.slf4j:slf4j-log4j12:1.7.25" - testCompile "io.projectreactor:reactor-test:3.1.0.RELEASE" + tasks.withType(GenerateModuleMetadata) { + enabled = false } - publishing { - publications { - mavenJava(MavenPublication) { - from components.java + plugins.withType(JavaPlugin) { - artifact sourcesJar { - classifier "sources" - } + compileJava { + sourceCompatibility = 1.8 - artifact javadocJar { - classifier "javadoc" - } + // TODO: Cleanup warnings so no need to exclude + options.compilerArgs << '-Xlint:all,-overloads,-rawtypes,-unchecked' + } + + javadoc { + def jdk = JavaVersion.current().majorVersion + def jdkJavadoc = "https://docs.oracle.com/javase/$jdk/docs/api/" + if (JavaVersion.current().isJava11Compatible()) { + jdkJavadoc = "https://docs.oracle.com/en/java/javase/$jdk/docs/api/" + } + options.with { + links jdkJavadoc + links 'https://projectreactor.io/docs/core/release/api/' + links 'https://netty.io/4.1/api/' } + failOnError = false } - } - artifactory { - publish { - contextUrl = 'https://oss.jfrog.org' - - repository { - repoKey = 'oss-snapshot-local' //The Artifactory repository key to publish to - //when using oss.jfrog.org the credentials are from Bintray. For local build we expect them to be found in - //~/.gradle/gradle.properties, otherwise to be set in the build server - // Conditionalize for the users who don't have bintray credentials setup - if (project.hasProperty('bintrayUser')) { - username = project.property('bintrayUser') - password = project.property('bintrayKey') + tasks.named("javadoc").configure { + onlyIf { System.getenv('SKIP_RELEASE') != "true" } + } + + test { + useJUnitPlatform() + testLogging { + events "PASSED", "FAILED" + showExceptions true + showCauses true + exceptionFormat "FULL" + stackTraceFilters "ENTRY_POINT" + maxGranularity 3 + } + + //show progress by displaying test classes, avoiding test suite timeouts + TestDescriptor last + afterTest { TestDescriptor td, TestResult tr -> + if (last != td.getParent()) { + last = td.getParent() + println last } } - publications('mavenJava') + if (isCiServer) { + def stdout = new LinkedList() + beforeTest { TestDescriptor td -> + stdout.clear() + } + onOutput { TestDescriptor td, TestOutputEvent toe -> + stdout.add(toe) + } + afterTest { TestDescriptor td, TestResult tr -> + if (tr.resultType == TestResult.ResultType.FAILURE && stdout.size() > 0) { + def stdOutput = stdout.collect { + it.getDestination().name() == "StdErr" + ? "STD_ERR: ${it.getMessage()}" + : "STD_OUT: ${it.getMessage()}" + } + .join() + println "This is the console output of the failing test below:\n$stdOutput" + } + } + + reports { + junitXml.outputPerTestCase = true + } + } + + if (JavaVersion.current().isJava9Compatible()) { + println "Java 9+: lowering MaxGCPauseMillis to 20ms in ${project.name} ${name}" + println "Java 9+: enabling leak detection [ADVANCED]" + jvmArgs = ["-XX:MaxGCPauseMillis=20", "-Dio.netty.leakDetection.level=ADVANCED", "-Dio.netty.leakDetection.samplingInterval=32"] + } - defaults { - // Reference to Gradle publications defined in the build script. - // This is how we tell the Artifactory Plugin which artifacts should be - // published to Artifactory. - publications('mavenJava') - publishArtifacts = true + systemProperty("java.awt.headless", "true") + systemProperty("testGroups", project.properties.get("testGroups")) + + //allow re-run of failed tests only without special test tasks failing + // because the filter is too restrictive + filter.setFailOnNoMatchingTests(false) + + //display intermediate results for special test tasks + afterSuite { desc, result -> + if (!desc.parent) { // will match the outermost suite + println('\n' + "${desc} Results: ${result.resultType} (${result.testCount} tests, ${result.successfulTestCount} successes, ${result.failedTestCount} failures, ${result.skippedTestCount} skipped)") + } } } } - artifactoryPublish { - dependsOn jar - } + plugins.withType(JavaLibraryPlugin) { + task sourcesJar(type: Jar) { + classifier 'sources' + from sourceSets.main.allJava + } - bintray { - if (project.hasProperty('bintrayUser')) { - user = project.property('bintrayUser') - key = project.property('bintrayKey') + task javadocJar(type: Jar, dependsOn: javadoc) { + classifier 'javadoc' + from javadoc.destinationDir } - publications = ['mavenJava'] - dryRun = false - publish = true - override = false - pkg { - repo = 'RSocket' - name = 'rsocket-java' - desc = 'RSocket' - websiteUrl = 'https://github.com/rsocket/rsocket-java' - issueTrackerUrl = 'https://github.com/rsocket/rsocket-java' - vcsUrl = 'https://github.com/rsocket/rsocket-java.git' - licenses = ['Apache-2.0'] - githubRepo = 'rsocket/rsocket-java' //Optional Github repository - githubReleaseNotesFile = 'README.md' //Optional Github readme file - if (project.hasProperty('sonatypeUsername') && project.hasProperty('sonatypePassword')) { - def sonatypeUsername = project.property('sonatypeUsername') - def sonatypePassword = project.property('sonatypePassword') - version { - name = "v${project.version}" - vcsTag = "${project.version}" - mavenCentralSync { - sync = false - user = sonatypeUsername - password = sonatypePassword + + plugins.withType(MavenPublishPlugin) { + publishing { + publications { + maven(MavenPublication) { + from components.java + artifact sourcesJar + artifact javadocJar } } } } } } + +apply from: "${rootDir}/gradle/publications.gradle" + +buildScan { + termsOfServiceUrl = 'https://gradle.com/terms-of-service' + termsOfServiceAgree = 'yes' +} + +description = 'RSocket: Stream Oriented Messaging Passing with Reactive Stream Semantics.' + +repositories { + mavenCentral() + + maven { url 'https://repo.spring.io/snapshot' } + mavenLocal() +} + +configurations { + adoc +} + +dependencies { + adoc "io.micrometer:micrometer-docs-generator-spans:1.0.0-SNAPSHOT" + adoc "io.micrometer:micrometer-docs-generator-metrics:1.0.0-SNAPSHOT" +} + +task generateObservabilityDocs(dependsOn: ["generateObservabilityMetricsDocs", "generateObservabilitySpansDocs"]) { +} + +task generateObservabilityMetricsDocs(type: JavaExec) { + mainClass = "io.micrometer.docs.metrics.DocsFromSources" + classpath configurations.adoc + args project.rootDir.getAbsolutePath(), ".*", project.rootProject.buildDir.getAbsolutePath() +} + +task generateObservabilitySpansDocs(type: JavaExec) { + mainClass = "io.micrometer.docs.spans.DocsFromSources" + classpath configurations.adoc + args project.rootDir.getAbsolutePath(), ".*", project.rootProject.buildDir.getAbsolutePath() +} diff --git a/gradle.properties b/gradle.properties index 75e1b324a..d138852c5 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,6 +1,4 @@ # -# Copyright 2016 Netflix, Inc. -# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,7 +11,5 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -mavenversion=0.9-SNAPSHOT -release.scope=patch -release.version=0.9-SNAPSHOT +version=1.2.0 +perfBaselineVersion=1.1.4 diff --git a/gradle/buildViaTravis.sh b/gradle/buildViaTravis.sh deleted file mode 100755 index 94ebd3f86..000000000 --- a/gradle/buildViaTravis.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash -# This script will build the project. - -if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then - echo -e "Build Pull Request #$TRAVIS_PULL_REQUEST => Branch [$TRAVIS_BRANCH]" - ./gradlew build -elif [ "$TRAVIS_PULL_REQUEST" == "false" ] && [ "$TRAVIS_TAG" == "" ]; then - echo -e 'Build Branch with Snapshot => Branch ['$TRAVIS_BRANCH']' - ./gradlew -PbintrayUser="${bintrayUser}" -PbintrayKey="${bintrayKey}" -PsonatypeUsername="${sonatypeUsername}" -PsonatypePassword="${sonatypePassword}" build artifactoryPublish --stacktrace -elif [ "$TRAVIS_PULL_REQUEST" == "false" ] && [ "$TRAVIS_TAG" != "" ]; then - echo -e 'Build Branch for Release => Branch ['$TRAVIS_BRANCH'] Tag ['$TRAVIS_TAG']' - ./gradlew -Pmavenversion="$TRAVIS_TAG" -PbintrayUser="${bintrayUser}" -PbintrayKey="${bintrayKey}" -PsonatypeUsername="${sonatypeUsername}" -PsonatypePassword="${sonatypePassword}" build bintrayUpload --stacktrace -else - echo -e 'WARN: Should not be here => Branch ['$TRAVIS_BRANCH'] Tag ['$TRAVIS_TAG'] Pull Request ['$TRAVIS_PULL_REQUEST']' - ./gradlew build -fi diff --git a/gradle/github-pkg.gradle b/gradle/github-pkg.gradle new file mode 100644 index 000000000..f53413766 --- /dev/null +++ b/gradle/github-pkg.gradle @@ -0,0 +1,21 @@ +subprojects { + + plugins.withType(MavenPublishPlugin) { + publishing { + repositories { + maven { + name = "GitHubPackages" + url = uri("https://maven.pkg.github.com/rsocket/rsocket-java") + credentials { + username = project.findProperty("gpr.user") ?: System.getenv("GITHUB_ACTOR") + password = project.findProperty("gpr.key") ?: System.getenv("GITHUB_TOKEN") + } + } + } + } + + tasks.named("publish").configure { + onlyIf { System.getenv('SKIP_RELEASE') != "true" } + } + } +} \ No newline at end of file diff --git a/gradle/publications.gradle b/gradle/publications.gradle new file mode 100644 index 000000000..9e8dd6d88 --- /dev/null +++ b/gradle/publications.gradle @@ -0,0 +1,53 @@ +apply from: "${rootDir}/gradle/github-pkg.gradle" +apply from: "${rootDir}/gradle/sonotype.gradle" + +subprojects { + plugins.withType(MavenPublishPlugin) { + publishing { + publications { + maven(MavenPublication) { + pom { + name = project.name + afterEvaluate { + description = project.description + } + groupId = 'io.rsocket' + url = 'http://rsocket.io' + licenses { + license { + name = "The Apache Software License, Version 2.0" + url = "https://www.apache.org/licenses/LICENSE-2.0.txt" + distribution = "repo" + } + } + developers { + developer { + id = 'OlegDokuka' + name = 'Oleh Dokuka' + email = 'oleh.dokuka@icloud.com' + } + developer { + id = 'rstoyanchev' + name = 'Rossen Stoyanchev' + email = 'rstoyanchev@vmware.com' + } + } + scm { + connection = 'scm:git:https://github.com/rsocket/rsocket-java.git' + developerConnection = 'scm:git:https://github.com/rsocket/rsocket-java.git' + url = 'https://github.com/rsocket/rsocket-java' + } + versionMapping { + usage('java-api') { + fromResolutionResult() + } + usage('java-runtime') { + fromResolutionResult() + } + } + } + } + } + } + } +} \ No newline at end of file diff --git a/gradle/sonotype.gradle b/gradle/sonotype.gradle new file mode 100644 index 000000000..f339079b0 --- /dev/null +++ b/gradle/sonotype.gradle @@ -0,0 +1,36 @@ +subprojects { + if (project.hasProperty('sonatypeUsername') && project.hasProperty('sonatypePassword')) { + plugins.withType(MavenPublishPlugin) { + plugins.withType(SigningPlugin) { + + signing { + //requiring signature if there is a publish task that is not to MavenLocal + required { gradle.taskGraph.allTasks.any { it.name.toLowerCase().contains("publish") && !it.name.contains("MavenLocal") } } + def signingKey = project.findProperty("signingKey") + def signingPassword = project.findProperty("signingPassword") + + useInMemoryPgpKeys(signingKey, signingPassword) + + afterEvaluate { + sign publishing.publications.maven + } + } + + publishing { + repositories { + maven { + name = "sonatype" + url = project.version.contains("-SNAPSHOT") + ? "https://oss.sonatype.org/content/repositories/snapshots/" + : "https://oss.sonatype.org/service/local/staging/deploy/maven2" + credentials { + username project.findProperty("sonatypeUsername") + password project.findProperty("sonatypePassword") + } + } + } + } + } + } + } +} \ No newline at end of file diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index ed88a042a..249e5832f 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index c583957d2..774fae876 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.1-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-4.2.1-all.zip diff --git a/gradlew b/gradlew index cccdd3d51..a69d9cb6c 100755 --- a/gradlew +++ b/gradlew @@ -1,78 +1,129 @@ -#!/usr/bin/env sh +#!/bin/sh + +# +# Copyright © 2015-2021 the original authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ############################################################################## -## -## Gradle start up script for UN*X -## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/master/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# ############################################################################## # Attempt to set APP_HOME + # Resolve links: $0 may be a link -PRG="$0" -# Need this for relative symlinks. -while [ -h "$PRG" ] ; do - ls=`ls -ld "$PRG"` - link=`expr "$ls" : '.*-> \(.*\)$'` - if expr "$link" : '/.*' > /dev/null; then - PRG="$link" - else - PRG=`dirname "$PRG"`"/$link" - fi +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac done -SAVED="`pwd`" -cd "`dirname \"$PRG\"`/" >/dev/null -APP_HOME="`pwd -P`" -cd "$SAVED" >/dev/null + +APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit APP_NAME="Gradle" -APP_BASE_NAME=`basename "$0"` +APP_BASE_NAME=${0##*/} # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -DEFAULT_JVM_OPTS="" +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' # Use the maximum available, or set MAX_FD != -1 to use that value. -MAX_FD="maximum" +MAX_FD=maximum warn () { echo "$*" -} +} >&2 die () { echo echo "$*" echo exit 1 -} +} >&2 # OS specific support (must be 'true' or 'false'). cygwin=false msys=false darwin=false nonstop=false -case "`uname`" in - CYGWIN* ) - cygwin=true - ;; - Darwin* ) - darwin=true - ;; - MINGW* ) - msys=true - ;; - NONSTOP* ) - nonstop=true - ;; +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; esac CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + # Determine the Java command to use to start the JVM. if [ -n "$JAVA_HOME" ] ; then if [ -x "$JAVA_HOME/jre/sh/java" ] ; then # IBM's JDK on AIX uses strange locations for the executables - JAVACMD="$JAVA_HOME/jre/sh/java" + JAVACMD=$JAVA_HOME/jre/sh/java else - JAVACMD="$JAVA_HOME/bin/java" + JAVACMD=$JAVA_HOME/bin/java fi if [ ! -x "$JAVACMD" ] ; then die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME @@ -81,7 +132,7 @@ Please set the JAVA_HOME variable in your environment to match the location of your Java installation." fi else - JAVACMD="java" + JAVACMD=java which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. Please set the JAVA_HOME variable in your environment to match the @@ -89,84 +140,101 @@ location of your Java installation." fi # Increase the maximum file descriptors if we can. -if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then - MAX_FD_LIMIT=`ulimit -H -n` - if [ $? -eq 0 ] ; then - if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then - MAX_FD="$MAX_FD_LIMIT" - fi - ulimit -n $MAX_FD - if [ $? -ne 0 ] ; then - warn "Could not set maximum file descriptor limit: $MAX_FD" - fi - else - warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" - fi +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac fi -# For Darwin, add options to specify how the application appears in the dock -if $darwin; then - GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" -fi +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. + +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) -# For Cygwin, switch paths to Windows format before running java -if $cygwin ; then - APP_HOME=`cygpath --path --mixed "$APP_HOME"` - CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` - JAVACMD=`cygpath --unix "$JAVACMD"` - - # We build the pattern for arguments to be converted via cygpath - ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` - SEP="" - for dir in $ROOTDIRSRAW ; do - ROOTDIRS="$ROOTDIRS$SEP$dir" - SEP="|" - done - OURCYGPATTERN="(^($ROOTDIRS))" - # Add a user-defined pattern to the cygpath arguments - if [ "$GRADLE_CYGPATTERN" != "" ] ; then - OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" - fi # Now convert the arguments - kludge to limit ourselves to /bin/sh - i=0 - for arg in "$@" ; do - CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` - CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option - - if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition - eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` - else - eval `echo args$i`="\"$arg\"" + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) fi - i=$((i+1)) + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg done - case $i in - (0) set -- ;; - (1) set -- "$args0" ;; - (2) set -- "$args0" "$args1" ;; - (3) set -- "$args0" "$args1" "$args2" ;; - (4) set -- "$args0" "$args1" "$args2" "$args3" ;; - (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; - (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; - (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; - (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; - (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; - esac fi -# Escape application args -save () { - for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done - echo " " -} -APP_ARGS=$(save "$@") - -# Collect all arguments for the java command, following the shell quoting and substitution rules -eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" - -# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong -if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then - cd "$(dirname "$0")" +# Collect all arguments for the java command; +# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of +# shell script including quotes and variable substitutions, so put them in +# double quotes to make sure that they get re-expanded; and +# * put everything else in single quotes, so that it's not re-expanded. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" fi +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' + exec "$JAVACMD" "$@" diff --git a/gradlew.bat b/gradlew.bat index e95643d6a..53a6b238d 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -1,4 +1,20 @@ -@if "%DEBUG%" == "" @echo off +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + +@if "%DEBUG%"=="" @echo off @rem ########################################################################## @rem @rem Gradle startup script for Windows @@ -9,19 +25,22 @@ if "%OS%"=="Windows_NT" setlocal set DIRNAME=%~dp0 -if "%DIRNAME%" == "" set DIRNAME=. +if "%DIRNAME%"=="" set DIRNAME=. set APP_BASE_NAME=%~n0 set APP_HOME=%DIRNAME% +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -set DEFAULT_JVM_OPTS= +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" @rem Find java.exe if defined JAVA_HOME goto findJavaFromJavaHome set JAVA_EXE=java.exe %JAVA_EXE% -version >NUL 2>&1 -if "%ERRORLEVEL%" == "0" goto init +if %ERRORLEVEL% equ 0 goto execute echo. echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. @@ -35,7 +54,7 @@ goto fail set JAVA_HOME=%JAVA_HOME:"=% set JAVA_EXE=%JAVA_HOME%/bin/java.exe -if exist "%JAVA_EXE%" goto init +if exist "%JAVA_EXE%" goto execute echo. echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% @@ -45,38 +64,26 @@ echo location of your Java installation. goto fail -:init -@rem Get command-line arguments, handling Windows variants - -if not "%OS%" == "Windows_NT" goto win9xME_args - -:win9xME_args -@rem Slurp the command line arguments. -set CMD_LINE_ARGS= -set _SKIP=2 - -:win9xME_args_slurp -if "x%~1" == "x" goto execute - -set CMD_LINE_ARGS=%* - :execute @rem Setup the command line set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + @rem Execute Gradle -"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* :end @rem End local scope for the variables with windows NT shell -if "%ERRORLEVEL%"=="0" goto mainEnd +if %ERRORLEVEL% equ 0 goto mainEnd :fail rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of rem the _cmd.exe /c_ return code! -if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 -exit /b 1 +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% :mainEnd if "%OS%"=="Windows_NT" endlocal diff --git a/rsocket-bom/build.gradle b/rsocket-bom/build.gradle new file mode 100755 index 000000000..a75ab3bc8 --- /dev/null +++ b/rsocket-bom/build.gradle @@ -0,0 +1,40 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +plugins { + id 'java-platform' + id 'maven-publish' + id 'signing' +} + +description = 'RSocket Java Bill of materials.' + +def excluded = ["rsocket-examples", "benchmarks"] + +dependencies { + constraints { + parent.subprojects.findAll { it.name != project.name && !excluded.contains(it.name) } .sort { "$it.name" }.each { + api it + } + } +} + +publishing { + publications { + maven(MavenPublication) { + from components.javaPlatform + } + } +} \ No newline at end of file diff --git a/rsocket-core/build.gradle b/rsocket-core/build.gradle index c3a0645e8..da5b69b14 100644 --- a/rsocket-core/build.gradle +++ b/rsocket-core/build.gradle @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -14,56 +14,47 @@ * limitations under the License. */ -apply plugin: 'com.github.johnrengelman.shadow' -apply plugin: 'me.champeau.gradle.jmh' -apply plugin: 'io.morethan.jmhreport' - -// disable tasks to stop generating duplicate files -jmhClasses.enabled = false -jmhRunBytecodeGenerator.enabled = false -jmhCompileGeneratedClasses.enabled = false - -jmhJar { - // add jmh classes to shadow jar - from project.configurations.jmh - from project.configurations.jmhRuntime - - // exclude logging classes - exclude 'org.slf4j:slf4j-log4j12' - exclude 'log4j:log4j' +plugins { + id 'java-library' + id 'maven-publish' + id 'signing' + id 'io.morethan.jmhreport' + id 'me.champeau.jmh' + id 'io.github.reyerizo.gradle.jcstress' } -jmh { - jmhVersion = '1.19' - includeTests = false - duplicateClassesStrategy = DuplicatesStrategy.WARN - zip64 = true - - jvmArgs = ['-XX:+UnlockCommercialFeatures', '-XX:+FlightRecorder'] - // NOTE: uncomment to add specific options - // jvmArgsAppend = ['-XX:+UseG1GC', '-Xms4g', '-Xmx4g'] - profilers = ['gc'] - resultFormat = 'JSON' - - // include = ['io.rsocket.RSocketPerf.fireAndForgetHello'] +dependencies { + api 'io.netty:netty-buffer' + api 'io.projectreactor:reactor-core' + + implementation 'org.slf4j:slf4j-api' + + testImplementation (project(":rsocket-transport-local")) + testImplementation 'io.projectreactor:reactor-test' + testImplementation 'org.assertj:assertj-core' + testImplementation 'org.junit.jupiter:junit-jupiter-api' + testImplementation 'org.junit.jupiter:junit-jupiter-params' + testImplementation 'org.mockito:mockito-junit-jupiter' + testImplementation 'org.awaitility:awaitility' + + testRuntimeOnly 'ch.qos.logback:logback-classic' + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' + + jcstressImplementation(project(":rsocket-test")) + jcstressImplementation 'org.slf4j:slf4j-api' + jcstressImplementation "ch.qos.logback:logback-classic" + jcstressImplementation 'io.projectreactor:reactor-test' } -// run report generation after benchmark -tasks.jmh.finalizedBy tasks.jmhReport - -jmhReport { - jmhResultPath = project.file('build/reports/jmh/results.json') - jmhReportOutput = project.file('build/reports/jmh') +jcstress { + mode = 'sanity' //sanity, quick, default, tough + jcstressDependency = "org.openjdk.jcstress:jcstress-core:0.16" } -// remove directory generated by IDEA during benchmark -clean { - project.delete('out') +jar { + manifest { + attributes("Automatic-Module-Name": "rsocket.core") + } } -dependencies { - jmh "org.openjdk.jmh:jmh-core:${jmh.jmhVersion}" - jmh "org.openjdk.jmh:jmh-generator-annprocess:${jmh.jmhVersion}" - - jmhRuntime "org.slf4j:slf4j-nop:$slf4jVersion" -} \ No newline at end of file +description = "Core functionality for the RSocket library" diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/FireAndForgetRequesterMonoStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/core/FireAndForgetRequesterMonoStressTest.java new file mode 100644 index 000000000..e91be2451 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/FireAndForgetRequesterMonoStressTest.java @@ -0,0 +1,115 @@ +package io.rsocket.core; + +import static org.openjdk.jcstress.annotations.Expect.ACCEPTABLE; + +import io.netty.buffer.ByteBuf; +import io.rsocket.test.TestDuplexConnection; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.LLLL_Result; + +public abstract class FireAndForgetRequesterMonoStressTest { + + abstract static class BaseStressTest { + + final StressSubscriber outboundSubscriber = new StressSubscriber<>(); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(); + + final TestDuplexConnection testDuplexConnection = + new TestDuplexConnection(this.outboundSubscriber, false); + + final TestRequesterResponderSupport requesterResponderSupport = + new TestRequesterResponderSupport(testDuplexConnection, StreamIdSupplier.clientSupplier()); + + final FireAndForgetRequesterMono source = source(); + + abstract FireAndForgetRequesterMono source(); + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 3, 1, 0"}, + expect = ACCEPTABLE) + @State + public static class TwoSubscribesRaceStressTest extends BaseStressTest { + + final StressSubscriber stressSubscriber1 = new StressSubscriber<>(); + + @Override + FireAndForgetRequesterMono source() { + return new FireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Actor + public void subscribe1() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void subscribe2() { + this.source.subscribe(this.stressSubscriber1); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber1.onCompleteCalls + + this.stressSubscriber1.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 1, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @State + public static class SubscribeAndCancelRaceStressTest extends BaseStressTest { + + @Override + FireAndForgetRequesterMono source() { + return new FireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = this.source.state; + r.r2 = this.stressSubscriber.onCompleteCalls + this.stressSubscriber.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/ReconnectMonoStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/core/ReconnectMonoStressTest.java new file mode 100644 index 000000000..ef79d344d --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/ReconnectMonoStressTest.java @@ -0,0 +1,604 @@ +/* + * Copyright 2015-Present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.core.ResolvingOperator.EMPTY_SUBSCRIBED; +import static io.rsocket.core.ResolvingOperator.EMPTY_UNSUBSCRIBED; +import static io.rsocket.core.ResolvingOperator.READY; +import static io.rsocket.core.ResolvingOperator.TERMINATED; +import static org.openjdk.jcstress.annotations.Expect.ACCEPTABLE; + +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.function.BiConsumer; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.IIIIIII_Result; +import org.openjdk.jcstress.infra.results.IIIIII_Result; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; + +public abstract class ReconnectMonoStressTest { + + abstract static class BaseStressTest { + + final StressSubscription stressSubscription = new StressSubscription<>(); + + final Mono source = source(); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(); + + volatile int onValueExpire; + + static final AtomicIntegerFieldUpdater ON_VALUE_EXPIRE = + AtomicIntegerFieldUpdater.newUpdater(BaseStressTest.class, "onValueExpire"); + + volatile int onValueReceived; + + static final AtomicIntegerFieldUpdater ON_VALUE_RECEIVED = + AtomicIntegerFieldUpdater.newUpdater(BaseStressTest.class, "onValueReceived"); + final ReconnectMono reconnectMono = + new ReconnectMono<>( + source, + (__) -> ON_VALUE_EXPIRE.incrementAndGet(BaseStressTest.this), + (__, ___) -> ON_VALUE_RECEIVED.incrementAndGet(BaseStressTest.this)); + + abstract Mono source(); + + int state() { + final BiConsumer[] subscribers = reconnectMono.resolvingInner.subscribers; + if (subscribers == EMPTY_UNSUBSCRIBED) { + return 0; + } else if (subscribers == EMPTY_SUBSCRIBED) { + return 1; + } else if (subscribers == READY) { + return 2; + } else if (subscribers == TERMINATED) { + return 3; + } else { + return 4; + } + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 0, 1, 1, 0, 3"}, + expect = ACCEPTABLE, + desc = "Disposed before value is delivered") + @Outcome( + id = {"0, 0, 0, 1, 1, 0, 3"}, + expect = ACCEPTABLE, + desc = "Disposed after onComplete but before value is delivered") + @Outcome( + id = {"0, 1, 1, 0, 1, 1, 3"}, + expect = ACCEPTABLE, + desc = "Disposed after value is delivered") + @State + public static class ExpireValueOnRacingDisposeAndNext extends BaseStressTest { + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + } + }; + } + + @Actor + void sendNext() { + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + + @Actor + void dispose() { + reconnectMono.dispose(); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + r.r1 = stressSubscription.cancelled ? 1 : 0; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = stressSubscriber.onCompleteCalls; + r.r4 = stressSubscriber.onErrorCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 0, 1, 1, 0, 3"}, + expect = ACCEPTABLE, + desc = "Disposed before error is delivered") + @Outcome( + id = {"0, 0, 0, 1, 1, 0, 3"}, + expect = ACCEPTABLE, + desc = "Disposed after onError") + @State + public static class ExpireValueOnRacingDisposeAndError extends BaseStressTest { + + { + Hooks.onErrorDropped(t -> {}); + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + } + }; + } + + @Actor + void sendNext() { + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onError(new RuntimeException("boom")); + } + + @Actor + void dispose() { + reconnectMono.dispose(); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + Hooks.resetOnErrorDropped(); + + r.r1 = stressSubscription.cancelled ? 1 : 0; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = stressSubscriber.onCompleteCalls; + r.r4 = stressSubscriber.onErrorCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"0, 1, 1, 0, 0, 1, 2"}, + expect = ACCEPTABLE, + desc = "Invalidate happens before value is delivered") + @Outcome( + id = {"0, 1, 1, 0, 1, 1, 0"}, + expect = ACCEPTABLE, + desc = "Invalidate happens after value is delivered") + @State + public static class ExpireValueOnRacingInvalidateAndNextComplete extends BaseStressTest { + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + } + }; + } + + @Actor + void sendNext() { + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + + @Actor + void invalidate() { + reconnectMono.invalidate(); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + r.r1 = stressSubscription.cancelled ? 1 : 0; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = stressSubscriber.onCompleteCalls; + r.r4 = stressSubscriber.onErrorCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"0, 1, 1, 0, 1, 1, 0"}, + expect = ACCEPTABLE) + @State + public static class ExpireValueOnceOnRacingInvalidateAndInvalidate extends BaseStressTest { + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + }; + } + + @Actor + void invalidate1() { + reconnectMono.invalidate(); + } + + @Actor + void invalidate2() { + reconnectMono.invalidate(); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + r.r1 = stressSubscription.cancelled ? 1 : 0; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = stressSubscriber.onCompleteCalls; + r.r4 = stressSubscriber.onErrorCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"0, 1, 1, 0, 1, 1, 3"}, + expect = ACCEPTABLE) + @State + public static class ExpireValueOnceOnRacingInvalidateAndDispose extends BaseStressTest { + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + }; + } + + @Actor + void invalidate() { + reconnectMono.invalidate(); + } + + @Actor + void dispose() { + reconnectMono.dispose(); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + r.r1 = stressSubscription.cancelled ? 1 : 0; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = stressSubscriber.onCompleteCalls; + r.r4 = stressSubscriber.onErrorCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 2, 2, 0, 1"}, + expect = ACCEPTABLE) + @State + public static class DeliversValueToAllSubscribersUnderRace extends BaseStressTest { + + final StressSubscriber stressSubscriber2 = new StressSubscriber<>(); + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + } + }; + } + + @Actor + void sendNextAndComplete() { + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + + @Actor + void secondSubscribe() { + reconnectMono.subscribe(stressSubscriber2); + } + + @Arbiter + public void arbiter(IIIIII_Result r) { + r.r1 = stressSubscription.requestsCount; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = stressSubscriber.onNextCalls + stressSubscriber2.onNextCalls; + r.r4 = stressSubscriber.onCompleteCalls + stressSubscriber2.onCompleteCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + } + } + + @JCStressTest + @Outcome( + id = {"2, 0, 1, 1, 1, 1, 4"}, + expect = ACCEPTABLE, + desc = "Second Subscriber subscribed after invalidate") + @Outcome( + id = {"1, 0, 2, 2, 1, 1, 0"}, + expect = ACCEPTABLE, + desc = "Second Subscriber subscribed before invalidate and received value") + @State + public static class InvalidateAndSubscribeUnderRace extends BaseStressTest { + + final StressSubscriber stressSubscriber2 = new StressSubscriber<>(); + + { + reconnectMono.subscribe(stressSubscriber); + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + } + }; + } + + @Actor + void invalidate() { + reconnectMono.invalidate(); + } + + @Actor + void secondSubscribe() { + reconnectMono.subscribe(stressSubscriber2); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + r.r1 = stressSubscription.subscribes; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = stressSubscriber.onNextCalls + stressSubscriber2.onNextCalls; + r.r4 = stressSubscriber.onCompleteCalls + stressSubscriber2.onCompleteCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"2, 0, 2, 1, 2, 2"}, + expect = ACCEPTABLE, + desc = "Subscribed again after invalidate") + @Outcome( + id = {"1, 0, 1, 1, 1, 0"}, + expect = ACCEPTABLE, + desc = "Subscribed before invalidate") + @State + public static class InvalidateAndBlockUnderRace extends BaseStressTest { + + String receivedValue; + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + actual.onNext("value" + stressSubscription.subscribes); + actual.onComplete(); + } + }; + } + + @Actor + void invalidate() { + reconnectMono.invalidate(); + } + + @Actor + void secondSubscribe() { + receivedValue = reconnectMono.block(); + } + + @Arbiter + public void arbiter(IIIIII_Result r) { + r.r1 = stressSubscription.subscribes; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = receivedValue.equals("value1") ? 1 : receivedValue.equals("value2") ? 2 : -1; + r.r4 = onValueExpire; + r.r5 = onValueReceived; + r.r6 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 1, 0, 1, 2"}, + expect = ACCEPTABLE) + @State + public static class TwoSubscribesRace extends BaseStressTest { + + StressSubscriber stressSubscriber2 = new StressSubscriber<>(); + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + actual.onNext("value" + stressSubscription.subscribes); + actual.onComplete(); + } + }; + } + + @Actor + void subscribe1() { + reconnectMono.subscribe(stressSubscriber); + } + + @Actor + void subscribe2() { + reconnectMono.subscribe(stressSubscriber2); + } + + @Arbiter + public void arbiter(IIIIII_Result r) { + r.r1 = stressSubscription.subscribes; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = stressSubscriber.values.get(0).equals(stressSubscriber2.values.get(0)) ? 1 : 2; + r.r4 = onValueExpire; + r.r5 = onValueReceived; + r.r6 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 1, 0, 1, 2"}, + expect = ACCEPTABLE) + @State + public static class SubscribeBlockConnectRace extends BaseStressTest { + + String receivedValue; + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + actual.onNext("value" + stressSubscription.subscribes); + actual.onComplete(); + } + }; + } + + @Actor + void block() { + receivedValue = reconnectMono.block(); + } + + @Actor + void subscribe() { + reconnectMono.subscribe(stressSubscriber); + } + + @Actor + void connect() { + reconnectMono.resolvingInner.connect(); + } + + @Arbiter + public void arbiter(IIIIII_Result r) { + r.r1 = stressSubscription.subscribes; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = receivedValue.equals(stressSubscriber.values.get(0)) ? 1 : 2; + r.r4 = onValueExpire; + r.r5 = onValueReceived; + r.r6 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 1, 0, 1, 2"}, + expect = ACCEPTABLE) + @State + public static class TwoBlocksRace extends BaseStressTest { + + String receivedValue1; + String receivedValue2; + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + actual.onNext("value" + stressSubscription.subscribes); + actual.onComplete(); + } + }; + } + + @Actor + void block1() { + receivedValue1 = reconnectMono.block(); + } + + @Actor + void block2() { + receivedValue2 = reconnectMono.block(); + } + + @Arbiter + public void arbiter(IIIIII_Result r) { + r.r1 = stressSubscription.subscribes; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = receivedValue1.equals(receivedValue2) ? 1 : 2; + r.r4 = onValueExpire; + r.r5 = onValueReceived; + r.r6 = state(); + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/RequestResponseRequesterMonoStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/core/RequestResponseRequesterMonoStressTest.java new file mode 100644 index 000000000..1dde77b34 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/RequestResponseRequesterMonoStressTest.java @@ -0,0 +1,650 @@ +package io.rsocket.core; + +import static org.openjdk.jcstress.annotations.Expect.ACCEPTABLE; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.test.TestDuplexConnection; +import java.util.stream.IntStream; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.LLLLLL_Result; +import org.openjdk.jcstress.infra.results.LLLLL_Result; +import org.openjdk.jcstress.infra.results.LLLL_Result; + +public abstract class RequestResponseRequesterMonoStressTest { + + abstract static class BaseStressTest { + + final StressSubscriber outboundSubscriber = new StressSubscriber<>(); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(initialRequest()); + + final TestDuplexConnection testDuplexConnection = + new TestDuplexConnection(this.outboundSubscriber, false); + + final RequesterLeaseTracker requesterLeaseTracker; + + final TestRequesterResponderSupport requesterResponderSupport; + + final RequestResponseRequesterMono source; + + BaseStressTest(RequesterLeaseTracker requesterLeaseTracker) { + this.requesterLeaseTracker = requesterLeaseTracker; + this.requesterResponderSupport = + new TestRequesterResponderSupport( + testDuplexConnection, StreamIdSupplier.clientSupplier(), requesterLeaseTracker); + this.source = source(); + } + + abstract RequestResponseRequesterMono source(); + + abstract long initialRequest(); + } + + abstract static class BaseStressTestWithLease extends BaseStressTest { + + BaseStressTestWithLease(int maximumAllowedAwaitingPermitHandlersNumber) { + super(new RequesterLeaseTracker("test", maximumAllowedAwaitingPermitHandlersNumber)); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 3, 1, 0, 0"}, + expect = ACCEPTABLE) + @State + public static class TwoSubscribesRaceStressTest extends BaseStressTestWithLease { + + final StressSubscriber stressSubscriber1 = new StressSubscriber<>(); + + public TwoSubscribesRaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + long initialRequest() { + return Long.MAX_VALUE; + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe1() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void subscribe2() { + this.source.subscribe(this.stressSubscriber1); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + final ByteBuf nextFrame = + PayloadFrameCodec.encode( + this.testDuplexConnection.alloc(), + 1, + false, + true, + true, + null, + ByteBufUtil.writeUtf8(this.testDuplexConnection.alloc(), "response-data")); + this.source.handleNext(nextFrame, false, true); + nextFrame.release(); + + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber1.onCompleteCalls + + this.stressSubscriber1.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + + this.outboundSubscriber.values.forEach(ByteBuf::release); + this.stressSubscriber.values.forEach(Payload::release); + this.stressSubscriber1.values.forEach(Payload::release); + + r.r5 = this.source.payload.refCnt() + nextFrame.refCnt(); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 0, 2, 0, 0, " + (0x04 + 2 * 0x09)}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @State + public static class SubscribeAndRequestAndCancelRaceStressTest extends BaseStressTestWithLease { + + public SubscribeAndRequestAndCancelRaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + long initialRequest() { + return 0; + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Actor + public void request() { + this.stressSubscriber.request(1); + this.stressSubscriber.request(Long.MAX_VALUE); + this.stressSubscriber.request(1); + } + + @Arbiter + public void arbiter(LLLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = this.stressSubscriber.onCompleteCalls + this.stressSubscriber.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + + r.r6 = + IntStream.range(0, this.outboundSubscriber.values.size()) + .map( + i -> + FrameHeaderCodec.frameType(this.outboundSubscriber.values.get(i)) + .getEncodedType() + * (i + 1)) + .sum(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 0, 2, 0, 0, " + (0x04 + 2 * 0x09)}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @State + public static class SubscribeAndRequestAndCancelWithDeferredLeaseRaceStressTest + extends BaseStressTestWithLease { + + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + + public SubscribeAndRequestAndCancelWithDeferredLeaseRaceStressTest() { + super(1); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + long initialRequest() { + return 0; + } + + @Actor + public void issueLease() { + final ByteBuf leaseFrame = this.leaseFrame; + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Actor + public void request() { + this.stressSubscriber.request(1); + this.stressSubscriber.request(Long.MAX_VALUE); + this.stressSubscriber.request(1); + } + + @Arbiter + public void arbiter(LLLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = this.stressSubscriber.onCompleteCalls + this.stressSubscriber.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + r.r6 = + IntStream.range(0, this.outboundSubscriber.values.size()) + .map( + i -> + FrameHeaderCodec.frameType(this.outboundSubscriber.values.get(i)) + .getEncodedType() + * (i + 1)) + .sum(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 0, 2, 0, 0, " + (0x04 + 2 * 0x09)}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 2, 0, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "NoLeaseError delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first or in between") + @Outcome( + id = {"-9223372036854775808, 3, 0, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = + "cancellation happened after lease permit requested but before it was actually decided and in the case when no lease are available. Error is dropped") + @State + public static class SubscribeAndRequestAndCancelWithDeferredLease2RaceStressTest + extends BaseStressTestWithLease { + + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + + SubscribeAndRequestAndCancelWithDeferredLease2RaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + long initialRequest() { + return 0; + } + + @Actor + public void issueLease() { + final ByteBuf leaseFrame = this.leaseFrame; + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Actor + public void request() { + this.stressSubscriber.request(1); + this.stressSubscriber.request(Long.MAX_VALUE); + this.stressSubscriber.request(1); + } + + @Arbiter + public void arbiter(LLLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + r.r6 = + IntStream.range(0, this.outboundSubscriber.values.size()) + .map( + i -> + FrameHeaderCodec.frameType(this.outboundSubscriber.values.get(i)) + .getEncodedType() + * (i + 1)) + .sum(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 0, 2, 0, " + (0x04 + 2 * 0x09)}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @State + public static class SubscribeAndRequestAndCancel extends BaseStressTest { + + SubscribeAndRequestAndCancel() { + super(null); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + long initialRequest() { + return 0; + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Actor + public void request() { + this.stressSubscriber.request(1); + this.stressSubscriber.request(Long.MAX_VALUE); + this.stressSubscriber.request(1); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.source.payload.refCnt(); + r.r5 = + IntStream.range(0, this.outboundSubscriber.values.size()) + .map( + i -> + FrameHeaderCodec.frameType(this.outboundSubscriber.values.get(i)) + .getEncodedType() + * (i + 1)) + .sum(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 1, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first or in between") + @State + public static class CancelWithInboundNextRaceStressTest extends BaseStressTestWithLease { + + final ByteBuf nextFrame = + PayloadFrameCodec.encode( + this.testDuplexConnection.alloc(), + 1, + false, + true, + true, + null, + ByteBufUtil.writeUtf8(this.testDuplexConnection.alloc(), "response-data")); + + CancelWithInboundNextRaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + + this.source.subscribe(this.stressSubscriber); + } + + @Override + long initialRequest() { + return 1; + } + + @Actor + public void inboundNext() { + this.source.handleNext(this.nextFrame, false, true); + this.nextFrame.release(); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.stressSubscriber.onNextCalls; + + this.outboundSubscriber.values.forEach(ByteBuf::release); + this.stressSubscriber.values.forEach(Payload::release); + + r.r4 = this.source.payload.refCnt() + this.nextFrame.refCnt(); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first or in between") + @State + public static class CancelWithInboundCompleteRaceStressTest extends BaseStressTestWithLease { + + CancelWithInboundCompleteRaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + + this.source.subscribe(this.stressSubscriber); + } + + @Override + long initialRequest() { + return 1; + } + + @Actor + public void inboundComplete() { + this.source.handleComplete(); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.stressSubscriber.onNextCalls; + + this.outboundSubscriber.values.forEach(ByteBuf::release); + this.stressSubscriber.values.forEach(Payload::release); + + r.r4 = this.source.payload.refCnt(); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 2, 0, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 3, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first. inbound error dropped") + @State + public static class CancelWithInboundErrorRaceStressTest extends BaseStressTestWithLease { + + static final RuntimeException ERROR = new RuntimeException("Test"); + + CancelWithInboundErrorRaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + + this.source.subscribe(this.stressSubscriber); + } + + @Override + long initialRequest() { + return 1; + } + + @Actor + public void inboundError() { + this.source.handleError(ERROR); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.stressSubscriber.onNextCalls; + + this.outboundSubscriber.values.forEach(ByteBuf::release); + this.stressSubscriber.values.forEach(Payload::release); + + r.r4 = this.source.payload.refCnt(); + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/SlowFireAndForgetRequesterMonoStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/core/SlowFireAndForgetRequesterMonoStressTest.java new file mode 100644 index 000000000..5de7eb4b9 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/SlowFireAndForgetRequesterMonoStressTest.java @@ -0,0 +1,288 @@ +package io.rsocket.core; + +import static org.openjdk.jcstress.annotations.Expect.ACCEPTABLE; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.test.TestDuplexConnection; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.LLLLL_Result; + +public abstract class SlowFireAndForgetRequesterMonoStressTest { + + abstract static class BaseStressTest { + + final StressSubscriber outboundSubscriber = new StressSubscriber<>(); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(); + + final TestDuplexConnection testDuplexConnection = + new TestDuplexConnection(this.outboundSubscriber, false); + + final RequesterLeaseTracker requesterLeaseTracker = + new RequesterLeaseTracker("test", maximumAllowedAwaitingPermitHandlersNumber()); + + final TestRequesterResponderSupport requesterResponderSupport = + new TestRequesterResponderSupport( + testDuplexConnection, StreamIdSupplier.clientSupplier(), requesterLeaseTracker); + + final SlowFireAndForgetRequesterMono source = source(); + + abstract SlowFireAndForgetRequesterMono source(); + + abstract int maximumAllowedAwaitingPermitHandlersNumber(); + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 3, 1, 0, 0"}, + expect = ACCEPTABLE) + @State + public static class TwoSubscribesRaceStressTest extends BaseStressTest { + + final StressSubscriber stressSubscriber1 = new StressSubscriber<>(); + + @Override + SlowFireAndForgetRequesterMono source() { + return new SlowFireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + int maximumAllowedAwaitingPermitHandlersNumber() { + return 0; + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe1() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void subscribe2() { + this.source.subscribe(this.stressSubscriber1); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber1.onCompleteCalls + + this.stressSubscriber1.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened in between") + @State + public static class SubscribeAndCancelRaceStressTest extends BaseStressTest { + + @Override + SlowFireAndForgetRequesterMono source() { + return new SlowFireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + int maximumAllowedAwaitingPermitHandlersNumber() { + return 0; + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = this.stressSubscriber.onCompleteCalls + this.stressSubscriber.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened in between") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @State + public static class SubscribeAndCancelWithDeferredLeaseRaceStressTest extends BaseStressTest { + + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + + @Override + SlowFireAndForgetRequesterMono source() { + return new SlowFireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + int maximumAllowedAwaitingPermitHandlersNumber() { + return 1; + } + + @Actor + public void issueLease() { + final ByteBuf leaseFrame = this.leaseFrame; + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = this.stressSubscriber.onCompleteCalls + this.stressSubscriber.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 2, 0, 1, 0"}, + expect = ACCEPTABLE, + desc = "no lease error delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened in between") + @Outcome( + id = {"-9223372036854775808, 3, 0, 1, 0"}, + expect = ACCEPTABLE, + desc = + "cancellation happened after lease permit requested but before it was actually decided and in the case when no lease are available. Error is dropped") + @State + public static class SubscribeAndCancelWithDeferredLease2RaceStressTest extends BaseStressTest { + + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + + @Override + SlowFireAndForgetRequesterMono source() { + return new SlowFireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + int maximumAllowedAwaitingPermitHandlersNumber() { + return 0; + } + + @Actor + public void issueLease() { + final ByteBuf leaseFrame = this.leaseFrame; + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscriber.java b/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscriber.java new file mode 100644 index 000000000..883077f77 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscriber.java @@ -0,0 +1,472 @@ +/* + * Copyright (c) 2020-Present Pivotal Software Inc, All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static reactor.core.publisher.Operators.addCap; + +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Consumer; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; + +public class StressSubscriber implements CoreSubscriber { + + enum Operation { + ON_NEXT, + ON_ERROR, + ON_COMPLETE, + ON_SUBSCRIBE + } + + final Context context; + final int requestedFusionMode; + + int fusionMode; + Subscription subscription; + + public Throwable error; + public boolean done; + + public List droppedErrors = new CopyOnWriteArrayList<>(); + + public List values = new ArrayList<>(); + + volatile long requested; + + @SuppressWarnings("rawtypes") + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(StressSubscriber.class, "requested"); + + volatile int wip; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "wip"); + + public volatile Operation guard; + + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater GUARD = + AtomicReferenceFieldUpdater.newUpdater(StressSubscriber.class, Operation.class, "guard"); + + public volatile boolean concurrentOnNext; + + public volatile boolean concurrentOnError; + + public volatile boolean concurrentOnComplete; + + public volatile boolean concurrentOnSubscribe; + + public volatile int onNextCalls; + + public BlockingQueue q = new LinkedBlockingDeque<>(); + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ON_NEXT_CALLS = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "onNextCalls"); + + public volatile int onNextDiscarded; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ON_NEXT_DISCARDED = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "onNextDiscarded"); + + public volatile int onErrorCalls; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ON_ERROR_CALLS = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "onErrorCalls"); + + public volatile int onCompleteCalls; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ON_COMPLETE_CALLS = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "onCompleteCalls"); + + public volatile int onSubscribeCalls; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ON_SUBSCRIBE_CALLS = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "onSubscribeCalls"); + + /** Build a {@link StressSubscriber} that makes an unbounded request upon subscription. */ + public StressSubscriber() { + this(Long.MAX_VALUE, Fuseable.NONE); + } + + /** + * Build a {@link StressSubscriber} that requests the provided amount in {@link + * #onSubscribe(Subscription)}. Use {@code 0} to avoid any initial request upon subscription. + * + * @param initRequest the requested amount upon subscription, or zero to disable initial request + */ + public StressSubscriber(long initRequest) { + this(initRequest, Fuseable.NONE); + } + + /** + * Build a {@link StressSubscriber} that requests the provided amount in {@link + * #onSubscribe(Subscription)}. Use {@code 0} to avoid any initial request upon subscription. + * + * @param initRequest the requested amount upon subscription, or zero to disable initial request + */ + public StressSubscriber(long initRequest, int requestedFusionMode) { + this.requestedFusionMode = requestedFusionMode; + this.context = + Operators.enableOnDiscard( + Context.of( + "reactor.onErrorDropped.local", + (Consumer) throwable -> droppedErrors.add(throwable)), + (__) -> ON_NEXT_DISCARDED.incrementAndGet(this)); + REQUESTED.lazySet(this, initRequest | Long.MIN_VALUE); + } + + @Override + public Context currentContext() { + return this.context; + } + + @Override + public void onSubscribe(Subscription subscription) { + if (!GUARD.compareAndSet(this, null, Operation.ON_SUBSCRIBE)) { + concurrentOnSubscribe = true; + subscription.cancel(); + } else { + final boolean isValid = Operators.validate(this.subscription, subscription); + if (isValid) { + this.subscription = subscription; + } + GUARD.compareAndSet(this, Operation.ON_SUBSCRIBE, null); + + if (this.requestedFusionMode > 0 && subscription instanceof Fuseable.QueueSubscription) { + final int m = + ((Fuseable.QueueSubscription) subscription).requestFusion(this.requestedFusionMode); + final long requested = this.requested; + this.fusionMode = m; + if (m != Fuseable.NONE) { + if (requested == Long.MAX_VALUE) { + subscription.cancel(); + } + drain(); + return; + } + } + + if (isValid) { + long delivered = 0; + for (; ; ) { + long s = requested; + if (s == Long.MAX_VALUE) { + subscription.cancel(); + break; + } + + long r = s & Long.MAX_VALUE; + long toRequest = r - delivered; + if (toRequest > 0) { + subscription.request(toRequest); + delivered = r; + } + + if (REQUESTED.compareAndSet(this, s, 0)) { + break; + } + } + } + } + ON_SUBSCRIBE_CALLS.incrementAndGet(this); + } + + @Override + public void onNext(T value) { + if (fusionMode == Fuseable.ASYNC) { + drain(); + return; + } + + if (!GUARD.compareAndSet(this, null, Operation.ON_NEXT)) { + concurrentOnNext = true; + } else { + values.add(value); + GUARD.compareAndSet(this, Operation.ON_NEXT, null); + } + ON_NEXT_CALLS.incrementAndGet(this); + } + + @Override + public void onError(Throwable throwable) { + if (!GUARD.compareAndSet(this, null, Operation.ON_ERROR)) { + concurrentOnError = true; + } else { + GUARD.compareAndSet(this, Operation.ON_ERROR, null); + } + + if (done) { + throw new IllegalStateException("Already done"); + } + + error = throwable; + done = true; + q.offer(throwable); + ON_ERROR_CALLS.incrementAndGet(this); + + if (fusionMode == Fuseable.ASYNC) { + drain(); + } + } + + @Override + public void onComplete() { + if (!GUARD.compareAndSet(this, null, Operation.ON_COMPLETE)) { + concurrentOnComplete = true; + } else { + GUARD.compareAndSet(this, Operation.ON_COMPLETE, null); + } + if (done) { + throw new IllegalStateException("Already done"); + } + + done = true; + ON_COMPLETE_CALLS.incrementAndGet(this); + + if (fusionMode == Fuseable.ASYNC) { + drain(); + } + } + + public void request(long n) { + if (Operators.validate(n)) { + for (; ; ) { + final long s = this.requested; + if (s == 0) { + this.subscription.request(n); + return; + } + + if ((s & Long.MIN_VALUE) != Long.MIN_VALUE) { + return; + } + + final long r = s & Long.MAX_VALUE; + if (r == Long.MAX_VALUE) { + return; + } + + final long u = addCap(r, n); + if (REQUESTED.compareAndSet(this, s, u | Long.MIN_VALUE)) { + if (this.fusionMode != Fuseable.NONE) { + drain(); + } + return; + } + } + } + } + + public void cancel() { + for (; ; ) { + long s = this.requested; + if (s == 0) { + this.subscription.cancel(); + return; + } + + if (REQUESTED.compareAndSet(this, s, Long.MAX_VALUE)) { + if (this.fusionMode != Fuseable.NONE) { + drain(); + } + return; + } + } + } + + @SuppressWarnings("unchecked") + private void drain() { + final int previousState = markWorkAdded(); + if (isFinalized(previousState)) { + ((Queue) this.subscription).clear(); + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + final Subscription s = this.subscription; + final Queue q = (Queue) s; + + int expectedState = previousState + 1; + for (; ; ) { + long r = this.requested & Long.MAX_VALUE; + long e = 0L; + + while (r != e) { + // done has to be read before queue.poll to ensure there was no racing: + // Thread1: <#drain>: queue.poll(null) --------------------> this.done(true) + // Thread2: ------------------> <#onNext(V)> --> <#onComplete()> + boolean done = this.done; + + final T t = q.poll(); + final boolean empty = t == null; + + if (checkTerminated(done, empty)) { + if (!empty) { + values.add(t); + } + return; + } + + if (empty) { + break; + } + + values.add(t); + + e++; + } + + if (r == e) { + // done has to be read before queue.isEmpty to ensure there was no racing: + // Thread1: <#drain>: queue.isEmpty(true) --------------------> this.done(true) + // Thread2: --------------------> <#onNext(V)> ---> <#onComplete()> + boolean done = this.done; + boolean empty = q.isEmpty(); + + if (checkTerminated(done, empty)) { + return; + } + } + + if (e != 0) { + ON_NEXT_CALLS.addAndGet(this, (int) e); + if (r != Long.MAX_VALUE) { + produce(e); + } + } + + expectedState = markWorkDone(expectedState); + if (!isWorkInProgress(expectedState)) { + return; + } + } + } + + boolean checkTerminated(boolean done, boolean empty) { + final long state = this.requested; + if (state == Long.MAX_VALUE) { + this.subscription.cancel(); + clearAndFinalize(); + return true; + } + + if (done && empty) { + clearAndFinalize(); + return true; + } + + return false; + } + + final void produce(long produced) { + for (; ; ) { + final long s = this.requested; + + if ((s & Long.MIN_VALUE) != Long.MIN_VALUE) { + return; + } + + final long r = s & Long.MAX_VALUE; + if (r == Long.MAX_VALUE) { + return; + } + + final long u = r - produced; + if (REQUESTED.compareAndSet(this, s, u | Long.MIN_VALUE)) { + return; + } + } + } + + @SuppressWarnings("unchecked") + final void clearAndFinalize() { + final Queue q = (Queue) this.subscription; + for (; ; ) { + final int state = this.wip; + + q.clear(); + + if (WIP.compareAndSet(this, state, Integer.MIN_VALUE)) { + return; + } + } + } + + final int markWorkAdded() { + for (; ; ) { + final int state = this.wip; + + if (isFinalized(state)) { + return state; + } + + int nextState = state + 1; + if ((nextState & Integer.MAX_VALUE) == 0) { + return state; + } + + if (WIP.compareAndSet(this, state, nextState)) { + return state; + } + } + } + + final int markWorkDone(int expectedState) { + for (; ; ) { + final int state = this.wip; + + if (expectedState != state) { + return state; + } + + if (isFinalized(state)) { + return state; + } + + if (WIP.compareAndSet(this, state, 0)) { + return 0; + } + } + } + + static boolean isFinalized(int state) { + return state == Integer.MIN_VALUE; + } + + static boolean isWorkInProgress(int state) { + return (state & Integer.MAX_VALUE) > 0; + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscription.java b/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscription.java new file mode 100644 index 000000000..3b51b8ef6 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscription.java @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2020-Present Pivotal Software Inc, All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Operators; + +public class StressSubscription implements Subscription { + + CoreSubscriber actual; + + public volatile int subscribes; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater SUBSCRIBES = + AtomicIntegerFieldUpdater.newUpdater(StressSubscription.class, "subscribes"); + + public volatile long requested; + + @SuppressWarnings("rawtypes") + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(StressSubscription.class, "requested"); + + public volatile int requestsCount; + + @SuppressWarnings("rawtype s") + static final AtomicIntegerFieldUpdater REQUESTS_COUNT = + AtomicIntegerFieldUpdater.newUpdater(StressSubscription.class, "requestsCount"); + + public volatile boolean cancelled; + + void subscribe(CoreSubscriber actual) { + this.actual = actual; + actual.onSubscribe(this); + SUBSCRIBES.getAndIncrement(this); + } + + @Override + public void request(long n) { + REQUESTS_COUNT.incrementAndGet(this); + Operators.addCap(REQUESTED, this, n); + } + + @Override + public void cancel() { + cancelled = true; + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/TestRequesterResponderSupport.java b/rsocket-core/src/jcstress/java/io/rsocket/core/TestRequesterResponderSupport.java new file mode 100644 index 000000000..420da66ba --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/TestRequesterResponderSupport.java @@ -0,0 +1,39 @@ +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.rsocket.DuplexConnection; +import io.rsocket.RSocket; +import io.rsocket.frame.decoder.PayloadDecoder; +import reactor.util.annotation.Nullable; + +public class TestRequesterResponderSupport extends RequesterResponderSupport implements RSocket { + + @Nullable private final RequesterLeaseTracker requesterLeaseTracker; + + public TestRequesterResponderSupport( + DuplexConnection connection, StreamIdSupplier streamIdSupplier) { + this(connection, streamIdSupplier, null); + } + + public TestRequesterResponderSupport( + DuplexConnection connection, + StreamIdSupplier streamIdSupplier, + @Nullable RequesterLeaseTracker requesterLeaseTracker) { + super( + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + PayloadDecoder.ZERO_COPY, + connection, + streamIdSupplier, + __ -> null); + this.requesterLeaseTracker = requesterLeaseTracker; + } + + @Override + @Nullable + public RequesterLeaseTracker getRequesterLeaseTracker() { + return this.requesterLeaseTracker; + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/UnpooledByteBufPayload.java b/rsocket-core/src/jcstress/java/io/rsocket/core/UnpooledByteBufPayload.java new file mode 100644 index 000000000..22c478979 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/UnpooledByteBufPayload.java @@ -0,0 +1,155 @@ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class UnpooledByteBufPayload extends AbstractReferenceCounted implements Payload { + + private final ByteBuf data; + private final ByteBuf metadata; + + /** + * Static factory method for a text payload. Mainly looks better than "new ByteBufPayload(data)" + * + * @param data the data of the payload. + * @return a payload. + */ + public static Payload create(String data) { + return create(data, ByteBufAllocator.DEFAULT); + } + + /** + * Static factory method for a text payload. Mainly looks better than "new ByteBufPayload(data)" + * + * @param data the data of the payload. + * @return a payload. + */ + public static Payload create(String data, ByteBufAllocator allocator) { + return new UnpooledByteBufPayload(ByteBufUtil.writeUtf8(allocator, data), null); + } + + /** + * Static factory method for a text payload. Mainly looks better than "new ByteBufPayload(data, + * metadata)" + * + * @param data the data of the payload. + * @param metadata the metadata for the payload. + * @return a payload. + */ + public static Payload create(String data, @Nullable String metadata) { + return create(data, metadata, ByteBufAllocator.DEFAULT); + } + + /** + * Static factory method for a text payload. Mainly looks better than "new ByteBufPayload(data, + * metadata)" + * + * @param data the data of the payload. + * @param metadata the metadata for the payload. + * @return a payload. + */ + public static Payload create(String data, @Nullable String metadata, ByteBufAllocator allocator) { + return new UnpooledByteBufPayload( + ByteBufUtil.writeUtf8(allocator, data), + metadata == null ? null : ByteBufUtil.writeUtf8(allocator, metadata)); + } + + public UnpooledByteBufPayload(ByteBuf data, @Nullable ByteBuf metadata) { + this.data = data; + this.metadata = metadata; + } + + @Override + public boolean hasMetadata() { + ensureAccessible(); + return metadata != null; + } + + @Override + public ByteBuf sliceMetadata() { + ensureAccessible(); + return metadata == null ? Unpooled.EMPTY_BUFFER : metadata.slice(); + } + + @Override + public ByteBuf data() { + ensureAccessible(); + return data; + } + + @Override + public ByteBuf metadata() { + ensureAccessible(); + return metadata == null ? Unpooled.EMPTY_BUFFER : metadata; + } + + @Override + public ByteBuf sliceData() { + ensureAccessible(); + return data.slice(); + } + + @Override + public UnpooledByteBufPayload retain() { + super.retain(); + return this; + } + + @Override + public UnpooledByteBufPayload retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public UnpooledByteBufPayload touch() { + ensureAccessible(); + data.touch(); + if (metadata != null) { + metadata.touch(); + } + return this; + } + + @Override + public UnpooledByteBufPayload touch(Object hint) { + ensureAccessible(); + data.touch(hint); + if (metadata != null) { + metadata.touch(hint); + } + return this; + } + + @Override + protected void deallocate() { + data.release(); + if (metadata != null) { + metadata.release(); + } + } + + /** + * Should be called by every method that tries to access the buffers content to check if the + * buffer was released before. + */ + void ensureAccessible() { + if (!isAccessible()) { + throw new IllegalReferenceCountException(0); + } + } + + /** + * Used internally by {@link UnpooledByteBufPayload#ensureAccessible()} to try to guard against + * using the buffer after it was released (best-effort). + */ + boolean isAccessible() { + return refCnt() != 0; + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/internal/UnboundedProcessorStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/internal/UnboundedProcessorStressTest.java new file mode 100644 index 000000000..a2d9fcf4d --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/internal/UnboundedProcessorStressTest.java @@ -0,0 +1,1733 @@ +package io.rsocket.internal; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.rsocket.core.StressSubscriber; +import io.rsocket.utils.FastLogger; +import java.util.Arrays; +import java.util.ConcurrentModificationException; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.Expect; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.LLLL_Result; +import org.openjdk.jcstress.infra.results.LLL_Result; +import org.openjdk.jcstress.infra.results.L_Result; +import reactor.core.Fuseable; +import reactor.core.publisher.Hooks; +import reactor.util.Logger; + +public abstract class UnboundedProcessorStressTest { + + static { + Hooks.onErrorDropped(t -> {}); + } + + final Logger logger = new FastLogger(getClass().getName()); + + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(logger); + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @Outcome( + id = { + "0, 0, 0", + "1, 0, 0", + "2, 0, 0", + "3, 0, 0", + "4, 0, 0", + // interleave with error or complete happened first but dispose suppressed them + "0, 3, 0", + "1, 3, 0", + "2, 3, 0", + "3, 3, 0", + "4, 3, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "cancel() before or interleave with dispose() || onError() || onComplete()") + @State + public static class SmokeStressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void request() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @Outcome( + id = { + "0, 0, 0", + "1, 0, 0", + "2, 0, 0", + "3, 0, 0", + "4, 0, 0", + // interleave with error or complete happened first but dispose suppressed them + "0, 3, 0", + "1, 3, 0", + "2, 3, 0", + "3, 3, 0", + "4, 3, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "cancel() before or interleave with dispose() || onError() || onComplete()") + @State + public static class SmokeFusedStressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void request() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @State + public static class Smoke2StressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + if (stressSubscriber.onCompleteCalls > 0 && stressSubscriber.onErrorCalls > 0) { + throw new RuntimeException("boom"); + } + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @State + public static class Smoke24StressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @State + public static class Smoke2FusedStressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @Outcome( + id = { + "0, 0, 0", + "1, 0, 0", + "2, 0, 0", + "3, 0, 0", + "4, 0, 0", + // interleave with error or complete happened first but dispose suppressed them + "0, 3, 0", + "1, 3, 0", + "2, 3, 0", + "3, 3, 0", + "4, 3, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "cancel() before or interleave with dispose() || onError() || onComplete()") + @State + public static class Smoke21FusedStressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", "1, 1, 0", "2, 1, 0", "3, 1, 0", "4, 1, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete()") + @Outcome( + id = { + "0, 0, 0", + "1, 0, 0", + "2, 0, 0", + "3, 0, 0", + "4, 0, 0", + // interleave with error or complete happened first but dispose suppressed them + "0, 3, 0", + "1, 3, 0", + "2, 3, 0", + "3, 3, 0", + "4, 3, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "cancel() before or interleave with onComplete()") + @State + public static class Smoke30StressTest extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void subscribeAndRequest() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", "1, 1, 0", "2, 1, 0", "3, 1, 0", "4, 1, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete()") + @State + public static class Smoke31StressTest extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void subscribeAndRequest() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + if (stressSubscriber.concurrentOnNext || stressSubscriber.concurrentOnComplete) { + throw new ConcurrentModificationException("boo"); + } + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", "1, 1, 0", "2, 1, 0", "3, 1, 0", "4, 1, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete()") + @State + public static class Smoke32StressTest extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = + new StressSubscriber<>(Long.MAX_VALUE, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0, 5", + "1, 1, 0, 5", + "2, 1, 0, 5", + "3, 1, 0, 5", + "4, 1, 0, 5", + "5, 1, 0, 5", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete()") + @State + public static class Smoke33StressTest extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = + new StressSubscriber<>(Long.MAX_VALUE, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + final ByteBuf byteBuf5 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(5); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void next1() { + unboundedProcessor.tryEmitNormal(byteBuf1); + unboundedProcessor.tryEmitPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.tryEmitPrioritized(byteBuf3); + unboundedProcessor.tryEmitNormal(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.tryEmitFinal(byteBuf5); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + r.r4 = stressSubscriber.values.get(stressSubscriber.values.size() - 1).readByte(); + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = + byteBuf1.refCnt() + + byteBuf2.refCnt() + + byteBuf3.refCnt() + + byteBuf4.refCnt() + + byteBuf5.refCnt(); + } + } + + @JCStressTest + @Outcome( + id = { + "-2954361355555045376, 4, 2, 0", + "-3242591731706757120, 4, 2, 0", + "-4107282860161892352, 4, 2, 0", + "-4395513236313604096, 4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 4, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 4, 0, 0", + "-7854277750134145024, 4, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 3, 2, 0", + "-3242591731706757120, 3, 2, 0", + "-4107282860161892352, 3, 2, 0", + "-4395513236313604096, 3, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 3, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 3, 0, 0", + "-7854277750134145024, 3, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 2, 2, 0", + "-3242591731706757120, 2, 2, 0", + "-4107282860161892352, 2, 2, 0", + "-4395513236313604096, 2, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 2, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 2, 0, 0", + "-7854277750134145024, 2, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 1, 2, 0", + "-3242591731706757120, 1, 2, 0", + "-4107282860161892352, 1, 2, 0", + "-4395513236313604096, 1, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 1, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 1, 0, 0", + "-7854277750134145024, 1, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 0, 2, 0", + "-3242591731706757120, 0, 2, 0", + "-4107282860161892352, 0, 2, 0", + "-4395513236313604096, 0, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 0, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 0, 0, 0", + "-7854277750134145024, 0, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @State + public static class RequestVsCancelVsOnNextVsDisposeStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void request() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "-3242591731706757120, 4, 2, 0", + "-4107282860161892352, 4, 2, 0", + "-4395513236313604096, 4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 4, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 3, 2, 0", + "-4107282860161892352, 3, 2, 0", + "-4395513236313604096, 3, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 3, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 2, 2, 0", + "-4107282860161892352, 2, 2, 0", + "-4395513236313604096, 2, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 2, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 1, 2, 0", + "-4107282860161892352, 1, 2, 0", + "-4395513236313604096, 1, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 1, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 0, 2, 0", + "-4107282860161892352, 0, 2, 0", + "-4395513236313604096, 0, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 0, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @State + public static class RequestVsCancelVsOnNextVsDisposeFusedStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void request() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "-2954361355555045376, 4, 2, 0", + "-3242591731706757120, 4, 2, 0", + "-4107282860161892352, 4, 2, 0", + "-4395513236313604096, 4, 2, 0", + "-4539628424389459968, 4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 4, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 4, 0, 0", + "-7854277750134145024, 4, 0, 0", + "-4539628424389459968, 4, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 3, 2, 0", + "-3242591731706757120, 3, 2, 0", + "-4107282860161892352, 3, 2, 0", + "-4395513236313604096, 3, 2, 0", + "-4539628424389459968, 3, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 3, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 3, 0, 0", + "-7854277750134145024, 3, 0, 0", + "-4539628424389459968, 3, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 2, 2, 0", + "-3242591731706757120, 2, 2, 0", + "-4107282860161892352, 2, 2, 0", + "-4395513236313604096, 2, 2, 0", + "-4539628424389459968, 2, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 2, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 2, 0, 0", + "-7854277750134145024, 2, 0, 0", + "-4539628424389459968, 2, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 1, 2, 0", + "-3242591731706757120, 1, 2, 0", + "-4107282860161892352, 1, 2, 0", + "-4395513236313604096, 1, 2, 0", + "-4539628424389459968, 1, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 1, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 1, 0, 0", + "-7854277750134145024, 1, 0, 0", + "-4539628424389459968, 1, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 0, 2, 0", + "-3242591731706757120, 0, 2, 0", + "-4107282860161892352, 0, 2, 0", + "-4395513236313604096, 0, 2, 0", + "-4539628424389459968, 0, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 0, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 0, 0, 0", + "-7854277750134145024, 0, 0, 0", + "-4539628424389459968, 0, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @State + public static class SubscribeWithFollowingRequestsVsOnNextVsDisposeStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "-3242591731706757120, 4, 2, 0", + "-4107282860161892352, 4, 2, 0", + "-4395513236313604096, 4, 2, 0", + "-4539628424389459968, 4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 4, 0, 0", + "-4539628424389459968, 4, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 3, 2, 0", + "-4107282860161892352, 3, 2, 0", + "-4395513236313604096, 3, 2, 0", + "-4539628424389459968, 3, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 3, 0, 0", + "-4539628424389459968, 3, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 2, 2, 0", + "-4107282860161892352, 2, 2, 0", + "-4395513236313604096, 2, 2, 0", + "-4539628424389459968, 2, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 2, 0, 0", + "-4539628424389459968, 2, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 1, 2, 0", + "-4107282860161892352, 1, 2, 0", + "-4395513236313604096, 1, 2, 0", + "-4539628424389459968, 1, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 1, 0, 0", + "-4539628424389459968, 1, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 0, 2, 0", + "-4107282860161892352, 0, 2, 0", + "-4395513236313604096, 0, 2, 0", + "-4539628424389459968, 0, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 0, 0, 0", + "-4539628424389459968, 0, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @State + public static class SubscribeWithFollowingRequestsVsOnNextVsDisposeFusedStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = {"-4539628424389459968, 0, 2, 0", "-3386706919782612992, 0, 2, 0"}, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = {"-4395513236313604096, 0, 2, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> dispose() before anything") + @Outcome( + id = {"-3242591731706757120, 0, 2, 0", "-3242591731706757120, 0, 0, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> (dispose() || cancel())") + @Outcome( + id = {"-7854277750134145024, 0, 0, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> cancel() before anything") + @State + public static class SubscribeWithFollowingCancelVsOnNextVsDisposeStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndCancel() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = {"-4539628424389459968, 0, 2, 0", "-3386706919782612992, 0, 2, 0"}, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = {"-4395513236313604096, 0, 2, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> dispose() before anything") + @Outcome( + id = {"-3242591731706757120, 0, 2, 0", "-3242591731706757120, 0, 0, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> (dispose() || cancel())") + @Outcome( + id = {"-7854277750134145024, 0, 0, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> cancel() before anything") + @State + public static class SubscribeWithFollowingCancelVsOnNextVsDisposeFusedStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndCancel() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = {"1"}, + expect = Expect.ACCEPTABLE) + @State + public static class SubscribeVsSubscribeStressTest extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber1 = new StressSubscriber<>(0, Fuseable.NONE); + final StressSubscriber stressSubscriber2 = new StressSubscriber<>(0, Fuseable.NONE); + + @Actor + public void subscribe1() { + unboundedProcessor.subscribe(stressSubscriber1); + } + + @Actor + public void subscribe2() { + unboundedProcessor.subscribe(stressSubscriber2); + } + + @Arbiter + public void arbiter(L_Result r) { + r.r1 = stressSubscriber1.onErrorCalls + stressSubscriber2.onErrorCalls; + + checkOutcomes(this, r.toString(), logger); + } + } + + static void checkOutcomes(Object instance, String result, Logger logger) { + if (Arrays.stream(instance.getClass().getDeclaredAnnotationsByType(Outcome.class)) + .flatMap(o -> Arrays.stream(o.id())) + .noneMatch(s -> s.equalsIgnoreCase(result))) { + throw new RuntimeException(result + " " + logger); + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/resume/InMemoryResumableFramesStoreStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/resume/InMemoryResumableFramesStoreStressTest.java new file mode 100644 index 000000000..f0b209552 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/resume/InMemoryResumableFramesStoreStressTest.java @@ -0,0 +1,118 @@ +package io.rsocket.resume; + +import static org.openjdk.jcstress.annotations.Expect.ACCEPTABLE; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.internal.UnboundedProcessor; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.LL_Result; +import reactor.core.Disposable; + +public class InMemoryResumableFramesStoreStressTest { + boolean storeClosed; + + InMemoryResumableFramesStore store = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 128); + boolean processorClosed; + UnboundedProcessor processor = new UnboundedProcessor(() -> processorClosed = true); + + void subscribe() { + store.saveFrames(processor).subscribe(); + store.onClose().subscribe(null, t -> storeClosed = true, () -> storeClosed = true); + } + + @JCStressTest + @Outcome( + id = {"true, true"}, + expect = ACCEPTABLE) + @State + public static class TwoSubscribesRaceStressTest extends InMemoryResumableFramesStoreStressTest { + + Disposable d1; + + final ByteBuf b1 = + PayloadFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + true, + false, + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello1"), + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello2")); + final ByteBuf b2 = + PayloadFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 3, + false, + true, + false, + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello3"), + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello4")); + final ByteBuf b3 = + PayloadFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 5, + false, + true, + false, + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello5"), + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello6")); + + final ByteBuf c1 = + ErrorFrameCodec.encode(ByteBufAllocator.DEFAULT, 0, new ConnectionErrorException("closed")); + + { + subscribe(); + d1 = store.doOnDiscard(ByteBuf.class, ByteBuf::release).subscribe(ByteBuf::release, t -> {}); + } + + @Actor + public void producer1() { + processor.tryEmitNormal(b1); + processor.tryEmitNormal(b2); + processor.tryEmitNormal(b3); + } + + @Actor + public void producer2() { + processor.tryEmitFinal(c1); + } + + @Actor + public void producer3() { + d1.dispose(); + store + .doOnDiscard(ByteBuf.class, ByteBuf::release) + .subscribe(ByteBuf::release, t -> {}) + .dispose(); + store + .doOnDiscard(ByteBuf.class, ByteBuf::release) + .subscribe(ByteBuf::release, t -> {}) + .dispose(); + store.doOnDiscard(ByteBuf.class, ByteBuf::release).subscribe(ByteBuf::release, t -> {}); + } + + @Actor + public void producer4() { + store.releaseFrames(0); + store.releaseFrames(0); + store.releaseFrames(0); + } + + @Arbiter + public void arbiter(LL_Result r) { + r.r1 = storeClosed; + r.r2 = processorClosed; + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/utils/FastLogger.java b/rsocket-core/src/jcstress/java/io/rsocket/utils/FastLogger.java new file mode 100644 index 000000000..c301d87cf --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/utils/FastLogger.java @@ -0,0 +1,137 @@ +package io.rsocket.utils; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import reactor.util.Logger; + +/** + * Implementation of {@link Logger} which is based on the {@link ThreadLocal} based queue which + * collects all the events on the per-thread basis.
Such logger is designed to have all events + * stored during the stress-test run and then sorted and printed out once all the Threads completed + * execution (inside the {@link org.openjdk.jcstress.annotations.Arbiter} annotated method.
+ * Note, this implementation only supports trace-level logs and ignores all others, it is intended + * to be used by {@link reactor.core.publisher.StateLogger}. + */ +public class FastLogger implements Logger { + + final Map> queues = new ConcurrentHashMap<>(); + + final ThreadLocal> logsQueueLocal = + ThreadLocal.withInitial( + () -> { + final ArrayList logs = new ArrayList<>(100); + queues.put(Thread.currentThread(), logs); + return logs; + }); + + private final String name; + + public FastLogger(String name) { + this.name = name; + } + + @Override + public String toString() { + return queues + .values() + .stream() + .flatMap(List::stream) + .sorted( + Comparator.comparingLong( + s -> { + Pattern pattern = Pattern.compile("\\[(.*?)]"); + Matcher matcher = pattern.matcher(s); + matcher.find(); + return Long.parseLong(matcher.group(1)); + })) + .collect(Collectors.joining("\n")); + } + + @Override + public String getName() { + return this.name; + } + + @Override + public boolean isTraceEnabled() { + return true; + } + + @Override + public void trace(String msg) { + logsQueueLocal.get().add(String.format("[%s] %s", System.nanoTime(), msg)); + } + + @Override + public void trace(String format, Object... arguments) { + trace(String.format(format, arguments)); + } + + @Override + public void trace(String msg, Throwable t) { + trace(String.format("%s, %s", msg, Arrays.toString(t.getStackTrace()))); + } + + @Override + public boolean isDebugEnabled() { + return false; + } + + @Override + public void debug(String msg) {} + + @Override + public void debug(String format, Object... arguments) {} + + @Override + public void debug(String msg, Throwable t) {} + + @Override + public boolean isInfoEnabled() { + return false; + } + + @Override + public void info(String msg) {} + + @Override + public void info(String format, Object... arguments) {} + + @Override + public void info(String msg, Throwable t) {} + + @Override + public boolean isWarnEnabled() { + return false; + } + + @Override + public void warn(String msg) {} + + @Override + public void warn(String format, Object... arguments) {} + + @Override + public void warn(String msg, Throwable t) {} + + @Override + public boolean isErrorEnabled() { + return false; + } + + @Override + public void error(String msg) {} + + @Override + public void error(String format, Object... arguments) {} + + @Override + public void error(String msg, Throwable t) {} +} diff --git a/rsocket-core/src/jcstress/resources/logback.xml b/rsocket-core/src/jcstress/resources/logback.xml new file mode 100644 index 000000000..e5877552c --- /dev/null +++ b/rsocket-core/src/jcstress/resources/logback.xml @@ -0,0 +1,39 @@ + + + + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/rsocket-core/src/jmh/java/io/rsocket/FragmentationPerf.java b/rsocket-core/src/jmh/java/io/rsocket/FragmentationPerf.java deleted file mode 100644 index 22ee60095..000000000 --- a/rsocket-core/src/jmh/java/io/rsocket/FragmentationPerf.java +++ /dev/null @@ -1,94 +0,0 @@ -package io.rsocket; - -import io.rsocket.fragmentation.FrameFragmenter; -import io.rsocket.fragmentation.FrameReassembler; -import io.rsocket.util.PayloadImpl; -import java.nio.ByteBuffer; -import java.util.concurrent.ThreadLocalRandom; -import java.util.stream.Collectors; -import org.openjdk.jmh.annotations.Benchmark; -import org.openjdk.jmh.annotations.BenchmarkMode; -import org.openjdk.jmh.annotations.Fork; -import org.openjdk.jmh.annotations.Measurement; -import org.openjdk.jmh.annotations.Mode; -import org.openjdk.jmh.annotations.Scope; -import org.openjdk.jmh.annotations.Setup; -import org.openjdk.jmh.annotations.State; -import org.openjdk.jmh.annotations.Warmup; -import org.openjdk.jmh.infra.Blackhole; - -@BenchmarkMode(Mode.Throughput) -@Fork( - value = 1 // , jvmArgsAppend = {"-Dio.netty.leakDetection.level=advanced"} -) -@Warmup(iterations = 10) -@Measurement(iterations = 10_000) -@State(Scope.Thread) -public class FragmentationPerf { - @State(Scope.Benchmark) - public static class Input { - Blackhole bh; - Frame smallFrame; - FrameFragmenter smallFrameFragmenter; - - Frame largeFrame; - FrameFragmenter largeFrameFragmenter; - - Iterable smallFramesIterable; - - @Setup - public void setup(Blackhole bh) { - this.bh = bh; - - ByteBuffer data = createRandomBytes(1 << 18); - ByteBuffer metadata = createRandomBytes(1 << 18); - largeFrame = - Frame.Request.from(1, FrameType.REQUEST_RESPONSE, new PayloadImpl(data, metadata), 1); - largeFrameFragmenter = new FrameFragmenter(1024); - - data = createRandomBytes(16); - metadata = createRandomBytes(16); - smallFrame = - Frame.Request.from(1, FrameType.REQUEST_RESPONSE, new PayloadImpl(data, metadata), 1); - smallFrameFragmenter = new FrameFragmenter(2); - smallFramesIterable = - smallFrameFragmenter - .fragment(smallFrame) - .map(Frame::copy) - .toStream() - .collect(Collectors.toList()); - } - } - - @Benchmark - public void smallFragmentationPerf(Input input) { - Frame frame = - input.smallFrameFragmenter.fragment(input.smallFrame).doOnNext(Frame::release).blockLast(); - input.bh.consume(frame); - } - - @Benchmark - public void largeFragmentationPerf(Input input) { - Frame frame = - input.largeFrameFragmenter.fragment(input.largeFrame).doOnNext(Frame::release).blockLast(); - input.bh.consume(frame); - } - - @Benchmark - public void smallFragmentationFrameReassembler(Input input) { - FrameReassembler smallFragmentAssembler = new FrameReassembler(input.smallFrame); - - input.smallFramesIterable.forEach(smallFragmentAssembler::append); - - Frame frame = smallFragmentAssembler.reassemble(); - input.bh.consume(frame); - frame.release(); - // input.smallFragmentAssembler.clear(); - } - - private static ByteBuffer createRandomBytes(int size) { - byte[] bytes = new byte[size]; - ThreadLocalRandom.current().nextBytes(bytes); - return ByteBuffer.wrap(bytes); - } -} diff --git a/rsocket-core/src/jmh/java/io/rsocket/RSocketPerf.java b/rsocket-core/src/jmh/java/io/rsocket/RSocketPerf.java deleted file mode 100644 index 215826e8e..000000000 --- a/rsocket-core/src/jmh/java/io/rsocket/RSocketPerf.java +++ /dev/null @@ -1,196 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import io.rsocket.RSocketFactory.Start; -import io.rsocket.perfutil.TestDuplexConnection; -import io.rsocket.util.PayloadImpl; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import org.openjdk.jmh.annotations.Benchmark; -import org.openjdk.jmh.annotations.BenchmarkMode; -import org.openjdk.jmh.annotations.Fork; -import org.openjdk.jmh.annotations.Measurement; -import org.openjdk.jmh.annotations.Mode; -import org.openjdk.jmh.annotations.Scope; -import org.openjdk.jmh.annotations.Setup; -import org.openjdk.jmh.annotations.State; -import org.openjdk.jmh.annotations.Warmup; -import org.openjdk.jmh.infra.Blackhole; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import reactor.core.publisher.DirectProcessor; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; - -@BenchmarkMode(Mode.Throughput) -@Fork( - value = 1 // , jvmArgsAppend = {"-Dio.netty.leakDetection.level=advanced"} -) -@Warmup(iterations = 10) -@Measurement(iterations = 10) -@State(Scope.Thread) -public class RSocketPerf { - - @Benchmark - public void requestResponseHello(Input input) { - try { - input.client.requestResponse(Input.HELLO_PAYLOAD).subscribe(input.blackHoleSubscriber); - } catch (Throwable t) { - t.printStackTrace(); - } - } - - @Benchmark - public void requestStreamHello1000(Input input) { - try { - input.client.requestStream(Input.HELLO_PAYLOAD).subscribe(input.blackHoleSubscriber); - } catch (Throwable t) { - t.printStackTrace(); - } - } - - @Benchmark - public void fireAndForgetHello(Input input) { - // this is synchronous so we don't need to use a CountdownLatch to wait - input.client.fireAndForget(Input.HELLO_PAYLOAD).subscribe(input.voidSubscriber); - } - - @State(Scope.Benchmark) - public static class Input { - /** Use to consume values when the test needs to return more than a single value. */ - public Blackhole bh; - - static final ByteBuffer HELLO = ByteBuffer.wrap("HELLO".getBytes(StandardCharsets.UTF_8)); - - static final Payload HELLO_PAYLOAD = new PayloadImpl(HELLO); - - static final DirectProcessor clientReceive = DirectProcessor.create(); - static final DirectProcessor serverReceive = DirectProcessor.create(); - - static final TestDuplexConnection clientConnection = - new TestDuplexConnection(serverReceive, clientReceive); - static final TestDuplexConnection serverConnection = - new TestDuplexConnection(clientReceive, serverReceive); - - static final Start server = - RSocketFactory.receive() - .acceptor( - (setup, sendingSocket) -> { - RSocket rSocket = - new RSocket() { - @Override - public Mono fireAndForget(Payload payload) { - return Mono.empty(); - } - - @Override - public Mono requestResponse(Payload payload) { - return Mono.just(HELLO_PAYLOAD); - } - - @Override - public Flux requestStream(Payload payload) { - return Flux.range(1, 1_000).flatMap(i -> requestResponse(payload)); - } - - @Override - public Flux requestChannel(Publisher payloads) { - return Flux.empty(); - } - - @Override - public Mono metadataPush(Payload payload) { - return Mono.empty(); - } - - @Override - public Mono close() { - return Mono.empty(); - } - - @Override - public Mono onClose() { - return Mono.empty(); - } - }; - - return Mono.just(rSocket); - }) - .transport( - acceptor -> { - Closeable closeable = - new Closeable() { - MonoProcessor onClose = MonoProcessor.create(); - - @Override - public Mono close() { - return Mono.empty().doFinally(s -> onClose.onComplete()).then(); - } - - @Override - public Mono onClose() { - return onClose; - } - }; - - acceptor.apply(serverConnection).subscribe(); - - return Mono.just(closeable); - }); - - Subscriber blackHoleSubscriber; - Subscriber voidSubscriber; - - RSocket client; - - @Setup - public void setup(Blackhole bh) { - blackHoleSubscriber = subscriber(bh); - voidSubscriber = subscriber(bh); - - client = - RSocketFactory.connect().transport(() -> Mono.just(clientConnection)).start().block(); - - this.bh = bh; - } - - private Subscriber subscriber(Blackhole bh) { - return new Subscriber() { - @Override - public void onSubscribe(Subscription s) { - s.request(Long.MAX_VALUE); - } - - @Override - public void onNext(T o) { - bh.consume(o); - } - - @Override - public void onError(Throwable t) { - t.printStackTrace(); - } - - @Override - public void onComplete() {} - }; - } - } -} diff --git a/rsocket-core/src/jmh/java/io/rsocket/perfutil/TestDuplexConnection.java b/rsocket-core/src/jmh/java/io/rsocket/perfutil/TestDuplexConnection.java deleted file mode 100644 index 47db25d45..000000000 --- a/rsocket-core/src/jmh/java/io/rsocket/perfutil/TestDuplexConnection.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.perfutil; - -import io.rsocket.DuplexConnection; -import io.rsocket.Frame; -import org.reactivestreams.Publisher; -import reactor.core.publisher.DirectProcessor; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -/** - * An implementation of {@link DuplexConnection} that provides functionality to modify the behavior - * dynamically. - */ -public class TestDuplexConnection implements DuplexConnection { - - private final DirectProcessor send; - private final DirectProcessor receive; - - public TestDuplexConnection(DirectProcessor send, DirectProcessor receive) { - this.send = send; - this.receive = receive; - } - - @Override - public Mono send(Publisher frame) { - return Flux.from(frame) - .doOnNext( - f -> { - try { - send.onNext(f); - } finally { - f.release(); - } - }) - .then(); - } - - @Override - public Mono sendOne(Frame frame) { - send.onNext(frame); - return Mono.empty(); - } - - @Override - public Flux receive() { - return receive; - } - - @Override - public double availability() { - return 1.0; - } - - @Override - public Mono close() { - return Mono.empty(); - } - - @Override - public Mono onClose() { - return Mono.empty(); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/AbstractRSocket.java b/rsocket-core/src/main/java/io/rsocket/AbstractRSocket.java deleted file mode 100644 index 3553309d8..000000000 --- a/rsocket-core/src/main/java/io/rsocket/AbstractRSocket.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; - -/** - * An abstract implementation of {@link RSocket}. All request handling methods emit {@link - * UnsupportedOperationException} and hence must be overridden to provide a valid implementation. - * - *

{@link #close()} and {@link #onClose()} returns a {@code Publisher} that never terminates. - */ -public abstract class AbstractRSocket implements RSocket { - - private final MonoProcessor onClose = MonoProcessor.create(); - - @Override - public Mono fireAndForget(Payload payload) { - return Mono.error(new UnsupportedOperationException("Fire and forget not implemented.")); - } - - @Override - public Mono requestResponse(Payload payload) { - return Mono.error(new UnsupportedOperationException("Request-Response not implemented.")); - } - - @Override - public Flux requestStream(Payload payload) { - return Flux.error(new UnsupportedOperationException("Request-Stream not implemented.")); - } - - @Override - public Flux requestChannel(Publisher payloads) { - return Flux.error(new UnsupportedOperationException("Request-Channel not implemented.")); - } - - @Override - public Mono metadataPush(Payload payload) { - return Mono.error(new UnsupportedOperationException("Metadata-Push not implemented.")); - } - - @Override - public Mono close() { - return Mono.defer( - () -> { - onClose.onComplete(); - return onClose; - }); - } - - @Override - public Mono onClose() { - return onClose; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/Availability.java b/rsocket-core/src/main/java/io/rsocket/Availability.java index b655f6d43..3361bcf8d 100644 --- a/rsocket-core/src/main/java/io/rsocket/Availability.java +++ b/rsocket-core/src/main/java/io/rsocket/Availability.java @@ -1,14 +1,17 @@ /* - * Copyright 2016 Netflix, Inc. - *

- * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - *

- * http://www.apache.org/licenses/LICENSE-2.0 - *

- * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on - * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package io.rsocket; diff --git a/rsocket-core/src/main/java/io/rsocket/Closeable.java b/rsocket-core/src/main/java/io/rsocket/Closeable.java index 0909de63d..2ea9a0371 100644 --- a/rsocket-core/src/main/java/io/rsocket/Closeable.java +++ b/rsocket-core/src/main/java/io/rsocket/Closeable.java @@ -1,41 +1,36 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package io.rsocket; +import org.reactivestreams.Subscriber; +import reactor.core.Disposable; import reactor.core.publisher.Mono; -/** */ -public interface Closeable { +/** An interface which allows listening to when a specific instance of this interface is closed */ +public interface Closeable extends Disposable { /** - * Close this {@code RSocket} upon subscribing to the returned {@code Publisher} + * Returns a {@link Mono} that terminates when the instance is terminated by any reason. Note, in + * case of error termination, the cause of error will be propagated as an error signal through + * {@link org.reactivestreams.Subscriber#onError(Throwable)}. Otherwise, {@link + * Subscriber#onComplete()} will be called. * - *

This method is idempotent and hence can be called as many times at any point with same - * outcome. - * - * @return A {@code Publisher} that completes when this {@code RSocket} close is complete. - */ - Mono close(); - - /** - * Returns a {@code Publisher} that completes when this {@code RSocket} is closed. A {@code - * RSocket} can be closed by explicitly calling {@link #close()} or when the underlying transport - * connection is closed. - * - * @return A {@code Publisher} that completes when this {@code RSocket} close is complete. + * @return a {@link Mono} to track completion with success or error of the underlying resource. + * When the underlying resource is an `RSocket`, the {@code Mono} exposes stream 0 (i.e. + * connection level) errors. */ Mono onClose(); } diff --git a/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java b/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java index fa4f7ff71..c39e679a1 100644 --- a/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java +++ b/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,120 +13,47 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket; -import static io.rsocket.frame.FrameHeaderFlyweight.FLAGS_M; +package io.rsocket; -import io.rsocket.frame.SetupFrameFlyweight; -import java.nio.ByteBuffer; +import io.netty.buffer.ByteBuf; +import io.netty.util.AbstractReferenceCounted; +import reactor.util.annotation.Nullable; /** - * Exposed to server for determination of RequestHandler based on mime types and SETUP metadata/data + * Exposes information from the {@code SETUP} frame to a server, as well as to client responders. */ -public abstract class ConnectionSetupPayload implements Payload { +public abstract class ConnectionSetupPayload extends AbstractReferenceCounted implements Payload { - public static final int NO_FLAGS = 0; - public static final int HONOR_LEASE = SetupFrameFlyweight.FLAGS_WILL_HONOR_LEASE; - public static final int STRICT_INTERPRETATION = SetupFrameFlyweight.FLAGS_STRICT_INTERPRETATION; + public abstract String metadataMimeType(); - public static ConnectionSetupPayload create(String metadataMimeType, String dataMimeType) { - return new ConnectionSetupPayloadImpl( - metadataMimeType, dataMimeType, Frame.NULL_BYTEBUFFER, Frame.NULL_BYTEBUFFER, NO_FLAGS); - } + public abstract String dataMimeType(); - public static ConnectionSetupPayload create( - String metadataMimeType, String dataMimeType, Payload payload) { - return new ConnectionSetupPayloadImpl( - metadataMimeType, - dataMimeType, - payload.getData(), - payload.getMetadata(), - payload.hasMetadata() ? FLAGS_M : 0); - } + public abstract int keepAliveInterval(); - public static ConnectionSetupPayload create( - String metadataMimeType, String dataMimeType, int flags) { - return new ConnectionSetupPayloadImpl( - metadataMimeType, dataMimeType, Frame.NULL_BYTEBUFFER, Frame.NULL_BYTEBUFFER, flags); - } + public abstract int keepAliveMaxLifetime(); - public static ConnectionSetupPayload create(final Frame setupFrame) { - Frame.ensureFrameType(FrameType.SETUP, setupFrame); - return new ConnectionSetupPayloadImpl( - Frame.Setup.metadataMimeType(setupFrame), - Frame.Setup.dataMimeType(setupFrame), - setupFrame.getData(), - setupFrame.getMetadata(), - Frame.Setup.getFlags(setupFrame)); - } + public abstract int getFlags(); - public abstract String metadataMimeType(); + public abstract boolean willClientHonorLease(); - public abstract String dataMimeType(); + public abstract boolean isResumeEnabled(); - public abstract int getFlags(); + @Nullable + public abstract ByteBuf resumeToken(); - public boolean willClientHonorLease() { - return Frame.isFlagSet(getFlags(), HONOR_LEASE); - } - - public boolean doesClientRequestStrictInterpretation() { - return STRICT_INTERPRETATION == (getFlags() & STRICT_INTERPRETATION); + @Override + public ConnectionSetupPayload retain() { + super.retain(); + return this; } @Override - public boolean hasMetadata() { - return Frame.isFlagSet(getFlags(), FLAGS_M); + public ConnectionSetupPayload retain(int increment) { + super.retain(increment); + return this; } - private static final class ConnectionSetupPayloadImpl extends ConnectionSetupPayload { - - private final String metadataMimeType; - private final String dataMimeType; - private final ByteBuffer data; - private final ByteBuffer metadata; - private final int flags; - - public ConnectionSetupPayloadImpl( - String metadataMimeType, - String dataMimeType, - ByteBuffer data, - ByteBuffer metadata, - int flags) { - this.metadataMimeType = metadataMimeType; - this.dataMimeType = dataMimeType; - this.data = data; - this.metadata = metadata; - this.flags = flags; - - if (!hasMetadata() && metadata.remaining() > 0) { - throw new IllegalArgumentException("metadata flag incorrect"); - } - } - - @Override - public String metadataMimeType() { - return metadataMimeType; - } - - @Override - public String dataMimeType() { - return dataMimeType; - } - - @Override - public ByteBuffer getData() { - return data; - } - - @Override - public ByteBuffer getMetadata() { - return metadata; - } - - @Override - public int getFlags() { - return flags; - } - } + @Override + public abstract ConnectionSetupPayload touch(); } diff --git a/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java index 0dea9d9d2..fe91f4bf0 100644 --- a/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java +++ b/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,65 +13,81 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package io.rsocket; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import java.net.SocketAddress; import java.nio.channels.ClosedChannelException; -import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; /** Represents a connection with input/output that the protocol uses. */ public interface DuplexConnection extends Availability, Closeable { /** - * Sends the source of {@link Frame}s on this connection and returns the {@code Publisher} - * representing the result of this send. - * - *

Flow control

+ * Delivers the given frame to the underlying transport connection. This method is non-blocking + * and can be safely executed from multiple threads. This method does not provide any flow-control + * mechanism. * - * The passed {@code Publisher} must - * - * @param frame Stream of {@code Frame}s to send on the connection. - * @return {@code Publisher} that completes when all the frames are written on the connection - * successfully and errors when it fails. + * @param streamId to which the given frame relates + * @param frame with the encoded content */ - Mono send(Publisher frame); + void sendFrame(int streamId, ByteBuf frame); /** - * Sends a single {@code Frame} on this connection and returns the {@code Publisher} representing - * the result of this send. + * Send an error frame and after it is successfully sent, close the connection. * - * @param frame {@code Frame} to send. - * @return {@code Publisher} that completes when the frame is written on the connection - * successfully and errors when it fails. + * @param errorException to encode in the error frame */ - default Mono sendOne(Frame frame) { - return send(Mono.just(frame)); - } + void sendErrorAndClose(RSocketErrorException errorException); /** * Returns a stream of all {@code Frame}s received on this connection. * - *

Completion

+ *

Completion * - * Returned {@code Publisher} MUST never emit a completion event ({@link - * Subscriber#onComplete()}. + *

Returned {@code Publisher} MUST never emit a completion event ({@link + * Subscriber#onComplete()}). * - *

Error

+ *

Error * - * Returned {@code Publisher} can error with various transport errors. If the underlying physical - * connection is closed by the peer, then the returned stream from here MUST emit an - * {@link ClosedChannelException}. + *

Returned {@code Publisher} can error with various transport errors. If the underlying + * physical connection is closed by the peer, then the returned stream from here MUST + * emit an {@link ClosedChannelException}. * - *

Multiple Subscriptions

+ *

Multiple Subscriptions * - * Returned {@code Publisher} is not required to support multiple concurrent subscriptions. + *

Returned {@code Publisher} is not required to support multiple concurrent subscriptions. * RSocket will never have multiple subscriptions to this source. Implementations MUST * emit an {@link IllegalStateException} for subsequent concurrent subscriptions, if they do not * support multiple concurrent subscriptions. * * @return Stream of all {@code Frame}s received. */ - Flux receive(); + Flux receive(); + + /** + * Returns the assigned {@link ByteBufAllocator}. + * + * @return the {@link ByteBufAllocator} + */ + ByteBufAllocator alloc(); + + /** + * Return the remote address that this connection is connected to. The returned {@link + * SocketAddress} varies by transport type and should be downcast to obtain more detailed + * information. For TCP and WebSocket, the address type is {@link java.net.InetSocketAddress}. For + * local transport, it is {@link io.rsocket.transport.local.LocalSocketAddress}. + * + * @return the address + * @since 1.1 + */ + SocketAddress remoteAddress(); + + @Override + default double availability() { + return isDisposed() ? 0.0 : 1.0; + } } diff --git a/rsocket-core/src/main/java/io/rsocket/Frame.java b/rsocket-core/src/main/java/io/rsocket/Frame.java deleted file mode 100644 index 1642e942d..000000000 --- a/rsocket-core/src/main/java/io/rsocket/Frame.java +++ /dev/null @@ -1,694 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket; - -import static io.rsocket.frame.FrameHeaderFlyweight.FLAGS_M; - -import io.netty.buffer.*; -import io.netty.util.IllegalReferenceCountException; -import io.netty.util.Recycler; -import io.netty.util.Recycler.Handle; -import io.netty.util.ResourceLeakDetector; -import io.rsocket.frame.ErrorFrameFlyweight; -import io.rsocket.frame.FrameHeaderFlyweight; -import io.rsocket.frame.KeepaliveFrameFlyweight; -import io.rsocket.frame.LeaseFrameFlyweight; -import io.rsocket.frame.RequestFrameFlyweight; -import io.rsocket.frame.RequestNFrameFlyweight; -import io.rsocket.frame.SetupFrameFlyweight; -import io.rsocket.frame.VersionFlyweight; -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import javax.annotation.Nullable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Represents a Frame sent over a {@link DuplexConnection}. - * - *

This provides encoding, decoding and field accessors. - */ -public class Frame implements ByteBufHolder { - public static final ByteBuffer NULL_BYTEBUFFER = ByteBuffer.allocateDirect(0); - - private static final Recycler RECYCLER = - new Recycler() { - protected Frame newObject(Handle handle) { - return new Frame(handle); - } - }; - - private final Handle handle; - private @Nullable ByteBuf content; - - private Frame(final Handle handle) { - this.handle = handle; - } - - /** Clear and recycle this instance. */ - private void recycle() { - content = null; - handle.recycle(this); - } - - /** Return the content which is held by this {@link Frame}. */ - @Override - public ByteBuf content() { - if (content.refCnt() <= 0) { - throw new IllegalReferenceCountException(content.refCnt()); - } - return content; - } - - /** Creates a deep copy of this {@link Frame}. */ - @Override - public Frame copy() { - return replace(content.copy()); - } - - /** - * Duplicates this {@link Frame}. Be aware that this will not automatically call {@link - * #retain()}. - */ - @Override - public Frame duplicate() { - return replace(content.duplicate()); - } - - /** - * Duplicates this {@link Frame}. This method returns a retained duplicate unlike {@link - * #duplicate()}. - * - * @see ByteBuf#retainedDuplicate() - */ - @Override - public Frame retainedDuplicate() { - return replace(content.retainedDuplicate()); - } - - /** Returns a new {@link Frame} which contains the specified {@code content}. */ - @Override - public Frame replace(ByteBuf content) { - return from(content); - } - - /** - * Returns the reference count of this object. If {@code 0}, it means this object has been - * deallocated. - */ - @Override - public int refCnt() { - return content.refCnt(); - } - - /** Increases the reference count by {@code 1}. */ - @Override - public Frame retain() { - content.retain(); - return this; - } - - /** Increases the reference count by the specified {@code increment}. */ - @Override - public Frame retain(int increment) { - content.retain(increment); - return this; - } - - /** - * Records the current access location of this object for debugging purposes. If this object is - * determined to be leaked, the information recorded by this operation will be provided to you via - * {@link ResourceLeakDetector}. This method is a shortcut to {@link #touch(Object) touch(null)}. - */ - @Override - public Frame touch() { - content.touch(); - return this; - } - - /** - * Records the current access location of this object with an additional arbitrary information for - * debugging purposes. If this object is determined to be leaked, the information recorded by this - * operation will be provided to you via {@link ResourceLeakDetector}. - */ - @Override - public Frame touch(@Nullable Object hint) { - content.touch(hint); - return this; - } - - /** - * Decreases the reference count by {@code 1} and deallocates this object if the reference count - * reaches at {@code 0}. - * - * @return {@code true} if and only if the reference count became {@code 0} and this object has - * been deallocated - */ - @Override - public boolean release() { - if (content.release()) { - recycle(); - return true; - } - return false; - } - - /** - * Decreases the reference count by the specified {@code decrement} and deallocates this object if - * the reference count reaches at {@code 0}. - * - * @return {@code true} if and only if the reference count became {@code 0} and this object has - * been deallocated - */ - @Override - public boolean release(int decrement) { - if (content.release(decrement)) { - recycle(); - return true; - } - return false; - } - - /** - * Return {@link ByteBuffer} that is a {@link ByteBuffer#slice()} for the frame metadata - * - *

If no metadata is present, the ByteBuffer will have 0 capacity. - * - * @return ByteBuffer containing the content - */ - public ByteBuffer getMetadata() { - final ByteBuf metadata = FrameHeaderFlyweight.sliceFrameMetadata(content); - if (metadata == null) { - return NULL_BYTEBUFFER; - } else if (metadata.readableBytes() > 0) { - final ByteBuffer buffer = ByteBuffer.allocateDirect(metadata.readableBytes()); - metadata.readBytes(buffer); - buffer.flip(); - return buffer; - } else { - return NULL_BYTEBUFFER; - } - } - - /** - * Return {@link ByteBuffer} that is a {@link ByteBuffer#slice()} for the frame data - * - *

If no data is present, the ByteBuffer will have 0 capacity. - * - * @return ByteBuffer containing the data - */ - public ByteBuffer getData() { - final ByteBuf data = FrameHeaderFlyweight.sliceFrameData(content); - if (data.readableBytes() > 0) { - final ByteBuffer buffer = ByteBuffer.allocateDirect(data.readableBytes()); - data.readBytes(buffer); - buffer.flip(); - return buffer; - } else { - return NULL_BYTEBUFFER; - } - } - - /** - * Return frame stream identifier - * - * @return frame stream identifier - */ - public int getStreamId() { - return FrameHeaderFlyweight.streamId(content); - } - - /** - * Return frame {@link FrameType} - * - * @return frame type - */ - public FrameType getType() { - return FrameHeaderFlyweight.frameType(content); - } - - /** - * Return the flags field for the frame - * - * @return frame flags field value - */ - public int flags() { - return FrameHeaderFlyweight.flags(content); - } - - /** - * Acquire a free Frame backed by given ByteBuf - * - * @param content to use as backing buffer - * @return frame - */ - public static Frame from(final ByteBuf content) { - final Frame frame = RECYCLER.get(); - frame.content = content; - - return frame; - } - - public static boolean isFlagSet(int flags, int checkedFlag) { - return (flags & checkedFlag) == checkedFlag; - } - - public static int setFlag(int current, int toSet) { - return current | toSet; - } - - public boolean hasMetadata() { - return Frame.isFlagSet(this.flags(), FLAGS_M); - } - - public String getDataUtf8() { - return StandardCharsets.UTF_8.decode(getData()).toString(); - } - - /* TODO: - * - * fromRequest(type, id, payload) - * fromKeepalive(ByteBuf content) - * - */ - - // SETUP specific getters - public static class Setup { - - private Setup() {} - - public static Frame from( - int flags, - int keepaliveInterval, - int maxLifetime, - String metadataMimeType, - String dataMimeType, - Payload payload) { - final ByteBuf metadata = - payload.hasMetadata() - ? Unpooled.wrappedBuffer(payload.getMetadata()) - : Unpooled.EMPTY_BUFFER; - final ByteBuf data = - payload.getData() != null - ? Unpooled.wrappedBuffer(payload.getData()) - : Unpooled.EMPTY_BUFFER; - - final Frame frame = RECYCLER.get(); - frame.content = - ByteBufAllocator.DEFAULT.buffer( - SetupFrameFlyweight.computeFrameLength( - flags, - metadataMimeType, - dataMimeType, - metadata.readableBytes(), - data.readableBytes())); - frame.content.writerIndex( - SetupFrameFlyweight.encode( - frame.content, - flags, - keepaliveInterval, - maxLifetime, - metadataMimeType, - dataMimeType, - metadata, - data)); - return frame; - } - - public static int getFlags(final Frame frame) { - ensureFrameType(FrameType.SETUP, frame); - final int flags = FrameHeaderFlyweight.flags(frame.content); - - return flags & SetupFrameFlyweight.VALID_FLAGS; - } - - public static int version(final Frame frame) { - ensureFrameType(FrameType.SETUP, frame); - return SetupFrameFlyweight.version(frame.content); - } - - public static int keepaliveInterval(final Frame frame) { - ensureFrameType(FrameType.SETUP, frame); - return SetupFrameFlyweight.keepaliveInterval(frame.content); - } - - public static int maxLifetime(final Frame frame) { - ensureFrameType(FrameType.SETUP, frame); - return SetupFrameFlyweight.maxLifetime(frame.content); - } - - public static String metadataMimeType(final Frame frame) { - ensureFrameType(FrameType.SETUP, frame); - return SetupFrameFlyweight.metadataMimeType(frame.content); - } - - public static String dataMimeType(final Frame frame) { - ensureFrameType(FrameType.SETUP, frame); - return SetupFrameFlyweight.dataMimeType(frame.content); - } - } - - public static class Error { - private static final Logger errorLogger = LoggerFactory.getLogger(Error.class); - - private Error() {} - - public static Frame from(int streamId, final Throwable throwable, ByteBuf dataBuffer) { - if (errorLogger.isDebugEnabled()) { - errorLogger.debug("an error occurred, creating error frame", throwable); - } - - final int code = ErrorFrameFlyweight.errorCodeFromException(throwable); - final Frame frame = RECYCLER.get(); - frame.content = - ByteBufAllocator.DEFAULT.buffer( - ErrorFrameFlyweight.computeFrameLength(dataBuffer.readableBytes())); - frame.content.writerIndex( - ErrorFrameFlyweight.encode(frame.content, streamId, code, dataBuffer)); - return frame; - } - - public static Frame from(int streamId, final Throwable throwable) { - String data = throwable.getMessage() == null ? "" : throwable.getMessage(); - byte[] bytes = data.getBytes(StandardCharsets.UTF_8); - - return from(streamId, throwable, Unpooled.wrappedBuffer(bytes)); - } - - public static int errorCode(final Frame frame) { - ensureFrameType(FrameType.ERROR, frame); - return ErrorFrameFlyweight.errorCode(frame.content); - } - - public static String message(Frame frame) { - ensureFrameType(FrameType.ERROR, frame); - return ErrorFrameFlyweight.message(frame.content); - } - } - - public static class Lease { - private Lease() {} - - public static Frame from(int ttl, int numberOfRequests, ByteBuf metadata) { - final Frame frame = RECYCLER.get(); - frame.content = - ByteBufAllocator.DEFAULT.buffer( - LeaseFrameFlyweight.computeFrameLength(metadata.readableBytes())); - frame.content.writerIndex( - LeaseFrameFlyweight.encode(frame.content, ttl, numberOfRequests, metadata)); - return frame; - } - - public static int ttl(final Frame frame) { - ensureFrameType(FrameType.LEASE, frame); - return LeaseFrameFlyweight.ttl(frame.content); - } - - public static int numberOfRequests(final Frame frame) { - ensureFrameType(FrameType.LEASE, frame); - return LeaseFrameFlyweight.numRequests(frame.content); - } - } - - public static class RequestN { - private RequestN() {} - - public static Frame from(int streamId, long requestN) { - int v = requestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) requestN; - return from(streamId, v); - } - - public static Frame from(int streamId, int requestN) { - if (requestN < 1) { - throw new IllegalStateException("request n must be greater than 0"); - } - - final Frame frame = RECYCLER.get(); - frame.content = ByteBufAllocator.DEFAULT.buffer(RequestNFrameFlyweight.computeFrameLength()); - frame.content.writerIndex(RequestNFrameFlyweight.encode(frame.content, streamId, requestN)); - return frame; - } - - public static int requestN(final Frame frame) { - ensureFrameType(FrameType.REQUEST_N, frame); - return RequestNFrameFlyweight.requestN(frame.content); - } - } - - public static class Request { - private Request() {} - - public static Frame from(int streamId, FrameType type, Payload payload, long initialRequestN) { - int v = initialRequestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) initialRequestN; - return from(streamId, type, payload, v); - } - - public static Frame from(int streamId, FrameType type, Payload payload, int initialRequestN) { - if (initialRequestN < 1) { - throw new IllegalStateException("initial request n must be greater than 0"); - } - final @Nullable ByteBuf metadata = - payload.hasMetadata() ? Unpooled.wrappedBuffer(payload.getMetadata()) : null; - final ByteBuf data = Unpooled.wrappedBuffer(payload.getData()); - - final Frame frame = RECYCLER.get(); - frame.content = - ByteBufAllocator.DEFAULT.buffer( - RequestFrameFlyweight.computeFrameLength( - type, metadata != null ? metadata.readableBytes() : null, data.readableBytes())); - - if (type.hasInitialRequestN()) { - frame.content.writerIndex( - RequestFrameFlyweight.encode( - frame.content, - streamId, - metadata != null ? FLAGS_M : 0, - type, - initialRequestN, - metadata, - data)); - } else { - frame.content.writerIndex( - RequestFrameFlyweight.encode( - frame.content, streamId, metadata != null ? FLAGS_M : 0, type, metadata, data)); - } - - return frame; - } - - public static Frame from(int streamId, FrameType type, int flags) { - final Frame frame = RECYCLER.get(); - frame.content = - ByteBufAllocator.DEFAULT.buffer(RequestFrameFlyweight.computeFrameLength(type, null, 0)); - frame.content.writerIndex( - RequestFrameFlyweight.encode( - frame.content, streamId, flags, type, Unpooled.EMPTY_BUFFER, Unpooled.EMPTY_BUFFER)); - return frame; - } - - public static Frame from( - int streamId, - FrameType type, - ByteBuf metadata, - ByteBuf data, - int initialRequestN, - int flags) { - final Frame frame = RECYCLER.get(); - frame.content = - ByteBufAllocator.DEFAULT.buffer( - RequestFrameFlyweight.computeFrameLength( - type, metadata.readableBytes(), data.readableBytes())); - frame.content.writerIndex( - RequestFrameFlyweight.encode( - frame.content, streamId, flags, type, initialRequestN, metadata, data)); - return frame; - } - - public static int initialRequestN(final Frame frame) { - final FrameType type = frame.getType(); - int result; - - if (!type.isRequestType()) { - throw new AssertionError("expected request type, but saw " + type.name()); - } - - switch (frame.getType()) { - case REQUEST_RESPONSE: - result = 1; - break; - case FIRE_AND_FORGET: - result = 0; - break; - default: - result = RequestFrameFlyweight.initialRequestN(frame.content); - break; - } - - return result; - } - - public static boolean isRequestChannelComplete(final Frame frame) { - ensureFrameType(FrameType.REQUEST_CHANNEL, frame); - final int flags = FrameHeaderFlyweight.flags(frame.content); - - return (flags & FrameHeaderFlyweight.FLAGS_C) == FrameHeaderFlyweight.FLAGS_C; - } - } - - public static class PayloadFrame { - - private PayloadFrame() {} - - public static Frame from(int streamId, FrameType type) { - return from(streamId, type, null, Unpooled.EMPTY_BUFFER, 0); - } - - public static Frame from(int streamId, FrameType type, Payload payload) { - return from(streamId, type, payload, payload.hasMetadata() ? FLAGS_M : 0); - } - - public static Frame from(int streamId, FrameType type, Payload payload, int flags) { - final ByteBuf metadata = - payload.hasMetadata() ? Unpooled.wrappedBuffer(payload.getMetadata()) : null; - final ByteBuf data = Unpooled.wrappedBuffer(payload.getData()); - return from(streamId, type, metadata, data, flags); - } - - public static Frame from( - int streamId, FrameType type, @Nullable ByteBuf metadata, ByteBuf data, int flags) { - final Frame frame = RECYCLER.get(); - frame.content = - ByteBufAllocator.DEFAULT.buffer( - FrameHeaderFlyweight.computeFrameHeaderLength( - type, metadata != null ? metadata.readableBytes() : null, data.readableBytes())); - frame.content.writerIndex( - FrameHeaderFlyweight.encode(frame.content, streamId, flags, type, metadata, data)); - return frame; - } - } - - public static class Cancel { - - private Cancel() {} - - public static Frame from(int streamId) { - final Frame frame = RECYCLER.get(); - frame.content = - ByteBufAllocator.DEFAULT.buffer( - FrameHeaderFlyweight.computeFrameHeaderLength(FrameType.CANCEL, null, 0)); - frame.content.writerIndex( - FrameHeaderFlyweight.encode( - frame.content, streamId, 0, FrameType.CANCEL, null, Unpooled.EMPTY_BUFFER)); - return frame; - } - } - - public static class Keepalive { - - private Keepalive() {} - - public static Frame from(ByteBuf data, boolean respond) { - final Frame frame = RECYCLER.get(); - frame.content = - ByteBufAllocator.DEFAULT.buffer( - KeepaliveFrameFlyweight.computeFrameLength(data.readableBytes())); - - final int flags = respond ? KeepaliveFrameFlyweight.FLAGS_KEEPALIVE_R : 0; - frame.content.writerIndex(KeepaliveFrameFlyweight.encode(frame.content, flags, data)); - - return frame; - } - - public static boolean hasRespondFlag(final Frame frame) { - ensureFrameType(FrameType.KEEPALIVE, frame); - final int flags = FrameHeaderFlyweight.flags(frame.content); - - return (flags & KeepaliveFrameFlyweight.FLAGS_KEEPALIVE_R) - == KeepaliveFrameFlyweight.FLAGS_KEEPALIVE_R; - } - } - - public static void ensureFrameType(final FrameType frameType, final Frame frame) { - final FrameType typeInFrame = frame.getType(); - - if (typeInFrame != frameType) { - throw new AssertionError("expected " + frameType + ", but saw" + typeInFrame); - } - } - - @Override - public String toString() { - FrameType type = FrameHeaderFlyweight.frameType(content); - StringBuilder payload = new StringBuilder(); - @Nullable ByteBuf metadata = FrameHeaderFlyweight.sliceFrameMetadata(content); - - if (metadata != null) { - if (0 < metadata.readableBytes()) { - payload.append( - String.format("metadata: \"%s\" ", metadata.toString(StandardCharsets.UTF_8))); - } - } - - ByteBuf data = FrameHeaderFlyweight.sliceFrameData(content); - if (0 < data.readableBytes()) { - payload.append(String.format("data: \"%s\" ", data.toString(StandardCharsets.UTF_8))); - } - - long streamId = FrameHeaderFlyweight.streamId(content); - - String additionalFlags = ""; - switch (type) { - case LEASE: - additionalFlags = " Permits: " + Lease.numberOfRequests(this) + " TTL: " + Lease.ttl(this); - break; - case REQUEST_N: - additionalFlags = " RequestN: " + RequestN.requestN(this); - break; - case KEEPALIVE: - additionalFlags = " Respond flag: " + Keepalive.hasRespondFlag(this); - break; - case REQUEST_STREAM: - case REQUEST_CHANNEL: - additionalFlags = " Initial Request N: " + Request.initialRequestN(this); - break; - case ERROR: - additionalFlags = " Error code: " + Error.errorCode(this); - break; - case SETUP: - int version = Setup.version(this); - additionalFlags = - " Version: " - + VersionFlyweight.toString(version) - + " keep-alive interval: " - + Setup.keepaliveInterval(this) - + " max lifetime: " - + Setup.maxLifetime(this) - + " metadata mime type: " - + Setup.metadataMimeType(this) - + " data mime type: " - + Setup.dataMimeType(this); - break; - } - - return "Frame => Stream ID: " - + streamId - + " Type: " - + type - + additionalFlags - + " Payload: " - + payload; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/FrameType.java b/rsocket-core/src/main/java/io/rsocket/FrameType.java deleted file mode 100644 index 9f1d6d740..000000000 --- a/rsocket-core/src/main/java/io/rsocket/FrameType.java +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket; - -/** Types of {@link Frame} that can be sent. */ -public enum FrameType { - // blank type that is not defined - UNDEFINED(0x00), - // Connection - SETUP(0x01, Flags.CAN_HAVE_METADATA_AND_DATA), - LEASE(0x02, Flags.CAN_HAVE_METADATA), - KEEPALIVE(0x03, Flags.CAN_HAVE_DATA), - // Requester to start request - REQUEST_RESPONSE(0x04, Flags.CAN_HAVE_METADATA_AND_DATA | Flags.IS_REQUEST_TYPE), - FIRE_AND_FORGET(0x05, Flags.CAN_HAVE_METADATA_AND_DATA | Flags.IS_REQUEST_TYPE), - REQUEST_STREAM( - 0x06, Flags.CAN_HAVE_METADATA_AND_DATA | Flags.IS_REQUEST_TYPE | Flags.HAS_INITIAL_REQUEST_N), - REQUEST_CHANNEL( - 0x07, Flags.CAN_HAVE_METADATA_AND_DATA | Flags.IS_REQUEST_TYPE | Flags.HAS_INITIAL_REQUEST_N), - // Requester mid-stream - REQUEST_N(0x08), - CANCEL(0x09, Flags.CAN_HAVE_METADATA), - // Responder - PAYLOAD(0x0A, Flags.CAN_HAVE_METADATA_AND_DATA), - ERROR(0x0B, Flags.CAN_HAVE_METADATA_AND_DATA), - // Requester & Responder - METADATA_PUSH(0x0C, Flags.CAN_HAVE_METADATA), - // Resumption frames, not yet implemented - RESUME(0x0D), - RESUME_OK(0x0E), - // synthetic types from Responder for use by the rest of the machinery - NEXT(0xA0, Flags.CAN_HAVE_METADATA_AND_DATA), - COMPLETE(0xB0), - NEXT_COMPLETE(0xC0, Flags.CAN_HAVE_METADATA_AND_DATA), - EXT(0xFFFF, Flags.CAN_HAVE_METADATA_AND_DATA); - - private static class Flags { - private Flags() {} - - private static final int CAN_HAVE_DATA = 0b0001; - private static final int CAN_HAVE_METADATA = 0b0010; - private static final int CAN_HAVE_METADATA_AND_DATA = 0b0011; - private static final int IS_REQUEST_TYPE = 0b0100; - private static final int HAS_INITIAL_REQUEST_N = 0b1000; - } - - private static FrameType[] typesById; - - private final int id; - private final int flags; - - /* Index types by id for indexed lookup. */ - static { - int max = 0; - - for (FrameType t : values()) { - max = Math.max(t.id, max); - } - - typesById = new FrameType[max + 1]; - - for (FrameType t : values()) { - typesById[t.id] = t; - } - } - - FrameType(final int id) { - this(id, 0); - } - - FrameType(int id, int flags) { - this.id = id; - this.flags = flags; - } - - public int getEncodedType() { - return id; - } - - public boolean isRequestType() { - return Flags.IS_REQUEST_TYPE == (flags & Flags.IS_REQUEST_TYPE); - } - - public boolean hasInitialRequestN() { - return Flags.HAS_INITIAL_REQUEST_N == (flags & Flags.HAS_INITIAL_REQUEST_N); - } - - public boolean canHaveData() { - return Flags.CAN_HAVE_DATA == (flags & Flags.CAN_HAVE_DATA); - } - - public boolean canHaveMetadata() { - return Flags.CAN_HAVE_METADATA == (flags & Flags.CAN_HAVE_METADATA); - } - - // TODO: offset of metadata and data (simplify parsing) naming: endOfFrameHeaderOffset() - public int payloadOffset() { - return 0; - } - - public static FrameType from(int id) { - return typesById[id]; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/Payload.java b/rsocket-core/src/main/java/io/rsocket/Payload.java index 00e0e360b..fc130528e 100644 --- a/rsocket-core/src/main/java/io/rsocket/Payload.java +++ b/rsocket-core/src/main/java/io/rsocket/Payload.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package io.rsocket; +import io.netty.buffer.ByteBuf; +import io.netty.util.ReferenceCounted; +import io.netty.util.ResourceLeakDetector; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; -/** Payload of a {@link Frame}. */ -public interface Payload { +/** Payload of a Frame . */ +public interface Payload extends ReferenceCounted { /** * Returns whether the payload has metadata, useful for tell if metadata is empty or not present. * @@ -28,25 +32,73 @@ public interface Payload { boolean hasMetadata(); /** - * Returns the Payload metadata. Always non-null, check {@link #hasMetadata()} to differentiate - * null from "". + * Returns a slice Payload metadata. Always non-null, check {@link #hasMetadata()} to + * differentiate null from "". * * @return payload metadata. */ - ByteBuffer getMetadata(); + ByteBuf sliceMetadata(); /** * Returns the Payload data. Always non-null. * * @return payload data. */ - ByteBuffer getData(); + ByteBuf sliceData(); + + /** + * Returns the Payloads' data without slicing if possible. This is not safe and editing this could + * effect the payload. It is recommended to call sliceData(). + * + * @return data as a bytebuf or slice of the data + */ + ByteBuf data(); + + /** + * Returns the Payloads' metadata without slicing if possible. This is not safe and editing this + * could effect the payload. It is recommended to call sliceMetadata(). + * + * @return metadata as a bytebuf or slice of the metadata + */ + ByteBuf metadata(); + + /** Increases the reference count by {@code 1}. */ + @Override + Payload retain(); + + /** Increases the reference count by the specified {@code increment}. */ + @Override + Payload retain(int increment); + + /** + * Records the current access location of this object for debugging purposes. If this object is + * determined to be leaked, the information recorded by this operation will be provided to you via + * {@link ResourceLeakDetector}. This method is a shortcut to {@link #touch(Object) touch(null)}. + */ + @Override + Payload touch(); + + /** + * Records the current access location of this object with an additional arbitrary information for + * debugging purposes. If this object is determined to be leaked, the information recorded by this + * operation will be provided to you via {@link ResourceLeakDetector}. + */ + @Override + Payload touch(Object hint); + + default ByteBuffer getMetadata() { + return sliceMetadata().nioBuffer(); + } + + default ByteBuffer getData() { + return sliceData().nioBuffer(); + } default String getMetadataUtf8() { - return StandardCharsets.UTF_8.decode(getMetadata()).toString(); + return sliceMetadata().toString(StandardCharsets.UTF_8); } default String getDataUtf8() { - return StandardCharsets.UTF_8.decode(getData()).toString(); + return sliceData().toString(StandardCharsets.UTF_8); } } diff --git a/rsocket-core/src/main/java/io/rsocket/RSocket.java b/rsocket-core/src/main/java/io/rsocket/RSocket.java index 0f006d5ca..b05241365 100644 --- a/rsocket-core/src/main/java/io/rsocket/RSocket.java +++ b/rsocket-core/src/main/java/io/rsocket/RSocket.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -33,7 +33,9 @@ public interface RSocket extends Availability, Closeable { * @return {@code Publisher} that completes when the passed {@code payload} is successfully * handled, otherwise errors. */ - Mono fireAndForget(Payload payload); + default Mono fireAndForget(Payload payload) { + return RSocketAdapter.fireAndForget(payload); + } /** * Request-Response interaction model of {@code RSocket}. @@ -42,7 +44,9 @@ public interface RSocket extends Availability, Closeable { * @return {@code Publisher} containing at most a single {@code Payload} representing the * response. */ - Mono requestResponse(Payload payload); + default Mono requestResponse(Payload payload) { + return RSocketAdapter.requestResponse(payload); + } /** * Request-Stream interaction model of {@code RSocket}. @@ -50,7 +54,9 @@ public interface RSocket extends Availability, Closeable { * @param payload Request payload. * @return {@code Publisher} containing the stream of {@code Payload}s representing the response. */ - Flux requestStream(Payload payload); + default Flux requestStream(Payload payload) { + return RSocketAdapter.requestStream(payload); + } /** * Request-Channel interaction model of {@code RSocket}. @@ -58,7 +64,9 @@ public interface RSocket extends Availability, Closeable { * @param payloads Stream of request payloads. * @return Stream of response payloads. */ - Flux requestChannel(Publisher payloads); + default Flux requestChannel(Publisher payloads) { + return RSocketAdapter.requestChannel(payloads); + } /** * Metadata-Push interaction model of {@code RSocket}. @@ -67,10 +75,25 @@ public interface RSocket extends Availability, Closeable { * @return {@code Publisher} that completes when the passed {@code payload} is successfully * handled, otherwise errors. */ - Mono metadataPush(Payload payload); + default Mono metadataPush(Payload payload) { + return RSocketAdapter.metadataPush(payload); + } @Override default double availability() { - return 0.0; + return isDisposed() ? 0.0 : 1.0; + } + + @Override + default void dispose() {} + + @Override + default boolean isDisposed() { + return false; + } + + @Override + default Mono onClose() { + return Mono.never(); } } diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketAdapter.java b/rsocket-core/src/main/java/io/rsocket/RSocketAdapter.java new file mode 100644 index 000000000..b5a64b8dd --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/RSocketAdapter.java @@ -0,0 +1,78 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Package private class with default implementations for use in {@link RSocket}. The main purpose + * is to hide static {@link UnsupportedOperationException} declarations. + * + * @since 1.0.3 + */ +class RSocketAdapter { + + private static final Mono UNSUPPORTED_FIRE_AND_FORGET = + Mono.error(new UnsupportedInteractionException("Fire-and-Forget")); + + private static final Mono UNSUPPORTED_REQUEST_RESPONSE = + Mono.error(new UnsupportedInteractionException("Request-Response")); + + private static final Flux UNSUPPORTED_REQUEST_STREAM = + Flux.error(new UnsupportedInteractionException("Request-Stream")); + + private static final Flux UNSUPPORTED_REQUEST_CHANNEL = + Flux.error(new UnsupportedInteractionException("Request-Channel")); + + private static final Mono UNSUPPORTED_METADATA_PUSH = + Mono.error(new UnsupportedInteractionException("Metadata-Push")); + + static Mono fireAndForget(Payload payload) { + payload.release(); + return RSocketAdapter.UNSUPPORTED_FIRE_AND_FORGET; + } + + static Mono requestResponse(Payload payload) { + payload.release(); + return RSocketAdapter.UNSUPPORTED_REQUEST_RESPONSE; + } + + static Flux requestStream(Payload payload) { + payload.release(); + return RSocketAdapter.UNSUPPORTED_REQUEST_STREAM; + } + + static Flux requestChannel(Publisher payloads) { + return RSocketAdapter.UNSUPPORTED_REQUEST_CHANNEL; + } + + static Mono metadataPush(Payload payload) { + payload.release(); + return RSocketAdapter.UNSUPPORTED_METADATA_PUSH; + } + + private static class UnsupportedInteractionException extends RuntimeException { + + private static final long serialVersionUID = 5084623297446471999L; + + UnsupportedInteractionException(String interactionName) { + super(interactionName + " not implemented.", null, false, false); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java b/rsocket-core/src/main/java/io/rsocket/RSocketClient.java deleted file mode 100644 index a1017e62c..000000000 --- a/rsocket-core/src/main/java/io/rsocket/RSocketClient.java +++ /dev/null @@ -1,581 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import io.netty.buffer.Unpooled; -import io.netty.util.collection.IntObjectHashMap; -import io.rsocket.exceptions.ConnectionException; -import io.rsocket.exceptions.Exceptions; -import io.rsocket.internal.LimitableRequestPublisher; -import io.rsocket.util.PayloadImpl; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import reactor.core.Disposable; -import reactor.core.publisher.*; - -import javax.annotation.Nullable; -import java.nio.channels.ClosedChannelException; -import java.time.Duration; -import java.util.Collection; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Consumer; -import java.util.function.Function; -import java.util.function.Supplier; - -import static io.rsocket.util.ExceptionUtil.noStacktrace; - -/** Client Side of a RSocket socket. Sends {@link Frame}s to a {@link RSocketServer} */ -class RSocketClient implements RSocket { - - private static final ClosedChannelException CLOSED_CHANNEL_EXCEPTION = - noStacktrace(new ClosedChannelException()); - - private final DuplexConnection connection; - private final Consumer errorConsumer; - private final StreamIdSupplier streamIdSupplier; - private final MonoProcessor started; - private final IntObjectHashMap senders; - private final IntObjectHashMap> receivers; - private final AtomicInteger missedAckCounter; - - private final EmitterProcessor sendProcessor; - - private @Nullable Disposable keepAliveSendSub; - private volatile long timeLastTickSentMs; - - RSocketClient( - DuplexConnection connection, - Consumer errorConsumer, - StreamIdSupplier streamIdSupplier) { - this(connection, errorConsumer, streamIdSupplier, Duration.ZERO, Duration.ZERO, 0); - } - - RSocketClient( - DuplexConnection connection, - Consumer errorConsumer, - StreamIdSupplier streamIdSupplier, - Duration tickPeriod, - Duration ackTimeout, - int missedAcks) { - this.connection = connection; - this.errorConsumer = errorConsumer; - this.streamIdSupplier = streamIdSupplier; - this.started = MonoProcessor.create(); - this.senders = new IntObjectHashMap<>(256, 0.9f); - this.receivers = new IntObjectHashMap<>(256, 0.9f); - this.missedAckCounter = new AtomicInteger(); - - // DO NOT Change the order here. The Send processor must be subscribed to before receiving - // connections - this.sendProcessor = EmitterProcessor.create(); - - if (!Duration.ZERO.equals(tickPeriod)) { - long ackTimeoutMs = ackTimeout.toMillis(); - - this.keepAliveSendSub = - started - .thenMany(Flux.interval(tickPeriod)) - .doOnSubscribe(s -> timeLastTickSentMs = System.currentTimeMillis()) - .flatMap(i -> sendKeepAlive(ackTimeoutMs, missedAcks)) - .doOnError( - t -> { - errorConsumer.accept(t); - connection.close().subscribe(); - }) - .subscribe(); - } - - connection.onClose().doFinally(signalType -> cleanup()).doOnError(errorConsumer).subscribe(); - - connection - .send(sendProcessor) - .doOnError(this::handleSendProcessorError) - .doFinally(this::handleSendProcessorCancel) - .subscribe(); - - connection - .receive() - .doOnSubscribe(subscription -> started.onComplete()) - .doOnNext(this::handleIncomingFrames) - .doOnError(errorConsumer) - .subscribe(); - } - - private void handleSendProcessorError(Throwable t) { - Collection> values; - Collection values1; - synchronized (RSocketClient.this) { - values = receivers.values(); - values1 = senders.values(); - } - - for (Subscriber subscriber : values) { - try { - subscriber.onError(t); - } catch (Throwable e) { - errorConsumer.accept(e); - } - } - - for (LimitableRequestPublisher p : values1) { - p.cancel(); - } - } - - private void handleSendProcessorCancel(SignalType t) { - if (SignalType.ON_ERROR == t) { - return; - } - Collection> values; - Collection values1; - synchronized (RSocketClient.this) { - values = receivers.values(); - values1 = senders.values(); - } - - for (Subscriber subscriber : values) { - try { - subscriber.onError(new Throwable("closed connection")); - } catch (Throwable e) { - errorConsumer.accept(e); - } - } - - for (LimitableRequestPublisher p : values1) { - p.cancel(); - } - } - - private Mono sendKeepAlive(long ackTimeoutMs, int missedAcks) { - return Mono.fromRunnable( - () -> { - long now = System.currentTimeMillis(); - if (now - timeLastTickSentMs > ackTimeoutMs) { - int count = missedAckCounter.incrementAndGet(); - if (count >= missedAcks) { - String message = - String.format( - "Missed %d keep-alive acks with a threshold of %d and a ack timeout of %d ms", - count, missedAcks, ackTimeoutMs); - throw new ConnectionException(message); - } - } - - sendProcessor.onNext(Frame.Keepalive.from(Unpooled.EMPTY_BUFFER, true)); - }); - } - - @Override - public Mono fireAndForget(Payload payload) { - Mono defer = - Mono.fromRunnable( - () -> { - final int streamId = streamIdSupplier.nextStreamId(); - final Frame requestFrame = - Frame.Request.from(streamId, FrameType.FIRE_AND_FORGET, payload, 1); - sendProcessor.onNext(requestFrame); - }); - - return started.then(defer); - } - - @Override - public Mono requestResponse(Payload payload) { - return handleRequestResponse(payload); - } - - @Override - public Flux requestStream(Payload payload) { - return handleRequestStream(payload); - } - - @Override - public Flux requestChannel(Publisher payloads) { - return handleStreamResponse(Flux.from(payloads), FrameType.REQUEST_CHANNEL); - } - - @Override - public Mono metadataPush(Payload payload) { - final Frame requestFrame = Frame.Request.from(0, FrameType.METADATA_PUSH, payload, 1); - sendProcessor.onNext(requestFrame); - return Mono.empty(); - } - - @Override - public double availability() { - return connection.availability(); - } - - @Override - public Mono close() { - return connection.close(); - } - - @Override - public Mono onClose() { - return connection.onClose(); - } - - public Flux handleRequestStream(final Payload payload) { - return started.thenMany( - Flux.defer( - () -> { - int streamId = streamIdSupplier.nextStreamId(); - - UnicastProcessor receiver = UnicastProcessor.create(); - - synchronized (this) { - receivers.put(streamId, receiver); - } - - AtomicBoolean first = new AtomicBoolean(false); - - return receiver - .doOnRequest( - l -> { - if (first.compareAndSet(false, true) && !receiver.isTerminated()) { - final Frame requestFrame = - Frame.Request.from(streamId, FrameType.REQUEST_STREAM, payload, l); - - sendProcessor.onNext(requestFrame); - } else if (contains(streamId) - && connection.availability() > 0.0 - && !receiver.isTerminated()) { - sendProcessor.onNext(Frame.RequestN.from(streamId, l)); - } - }) - .doOnError( - t -> { - if (contains(streamId) - && connection.availability() > 0.0 - && !receiver.isTerminated()) { - sendProcessor.onNext(Frame.Error.from(streamId, t)); - } - }) - .doOnCancel( - () -> { - if (contains(streamId) - && connection.availability() > 0.0 - && !receiver.isTerminated()) { - sendProcessor.onNext(Frame.Cancel.from(streamId)); - } - }) - .doFinally(s -> removeReceiver(streamId)); - })); - } - - private Mono handleRequestResponse(final Payload payload) { - return started.then( - Mono.defer( - () -> { - int streamId = streamIdSupplier.nextStreamId(); - final Frame requestFrame = - Frame.Request.from(streamId, FrameType.REQUEST_RESPONSE, payload, 1); - - MonoProcessor receiver = MonoProcessor.create(); - - synchronized (this) { - receivers.put(streamId, receiver); - } - - sendProcessor.onNext(requestFrame); - - return receiver - .doOnError(t -> sendProcessor.onNext(Frame.Error.from(streamId, t))) - .doOnCancel(() -> sendProcessor.onNext(Frame.Cancel.from(streamId))) - .doFinally(s -> removeReceiver(streamId)); - })); - } - - private Flux handleStreamResponse(Flux request, FrameType requestType) { - return started.thenMany( - Flux.defer( - new Supplier>() { - final UnicastProcessor receiver = UnicastProcessor.create(); - final int streamId = streamIdSupplier.nextStreamId(); - volatile @Nullable MonoProcessor subscribedRequests; - boolean firstRequest = true; - - boolean isValidToSendFrame() { - return contains(streamId) - && connection.availability() > 0.0 - && !receiver.isTerminated(); - } - - void sendOneFrame(Frame frame) { - if (isValidToSendFrame()) { - sendProcessor.onNext(frame); - } - } - - @Override - public Flux get() { - return receiver - .doOnRequest( - l -> { - boolean _firstRequest = false; - synchronized (RSocketClient.this) { - if (firstRequest) { - _firstRequest = true; - firstRequest = false; - } - } - - if (_firstRequest) { - Flux requestFrames = - request - .transform( - f -> { - LimitableRequestPublisher wrapped = - LimitableRequestPublisher.wrap(f); - // Need to set this to one for first the frame - wrapped.increaseRequestLimit(1); - synchronized (RSocketClient.this) { - senders.put(streamId, wrapped); - receivers.put(streamId, receiver); - } - - return wrapped; - }) - .map( - new Function() { - boolean firstPayload = true; - - @Override - public Frame apply(Payload payload) { - boolean _firstPayload = false; - synchronized (this) { - if (firstPayload) { - firstPayload = false; - _firstPayload = true; - } - } - - if (_firstPayload) { - return Frame.Request.from( - streamId, requestType, payload, l); - } else { - return Frame.PayloadFrame.from( - streamId, FrameType.NEXT, payload); - } - } - }) - .doOnComplete( - () -> { - if (FrameType.REQUEST_CHANNEL == requestType) { - sendOneFrame( - Frame.PayloadFrame.from( - streamId, FrameType.COMPLETE)); - } - }); - - requestFrames - .doOnNext(sendProcessor::onNext) - .doOnError( - t -> { - errorConsumer.accept(t); - receiver.cancel(); - }) - .subscribe(); - } else { - sendOneFrame(Frame.RequestN.from(streamId, l)); - } - }) - .doOnError(t -> sendOneFrame(Frame.Error.from(streamId, t))) - .doOnCancel( - () -> { - sendOneFrame(Frame.Cancel.from(streamId)); - if (subscribedRequests != null) { - subscribedRequests.cancel(); - } - }) - .doFinally( - s -> { - removeReceiver(streamId); - removeSender(streamId); - }); - } - })); - } - - private boolean contains(int streamId) { - synchronized (RSocketClient.this) { - return receivers.containsKey(streamId); - } - } - - protected void cleanup() { - Collection> subscribers; - Collection publishers; - synchronized (RSocketClient.this) { - subscribers = receivers.values(); - publishers = senders.values(); - - senders.clear(); - receivers.clear(); - } - - subscribers.forEach(this::cleanUpSubscriber); - publishers.forEach(this::cleanUpLimitableRequestPublisher); - - if (null != keepAliveSendSub) { - keepAliveSendSub.dispose(); - } - } - - private synchronized void cleanUpLimitableRequestPublisher( - LimitableRequestPublisher limitableRequestPublisher) { - try { - limitableRequestPublisher.cancel(); - } catch (Throwable t) { - errorConsumer.accept(t); - } - } - - private synchronized void cleanUpSubscriber(Subscriber subscriber) { - try { - subscriber.onError(CLOSED_CHANNEL_EXCEPTION); - } catch (Throwable t) { - errorConsumer.accept(t); - } - } - - private void handleIncomingFrames(Frame frame) { - try { - int streamId = frame.getStreamId(); - FrameType type = frame.getType(); - if (streamId == 0) { - handleStreamZero(type, frame); - } else { - handleFrame(streamId, type, frame); - } - } finally { - frame.release(); - } - } - - private void handleStreamZero(FrameType type, Frame frame) { - switch (type) { - case ERROR: - throw Exceptions.from(frame); - case LEASE: - { - break; - } - case KEEPALIVE: - if (!Frame.Keepalive.hasRespondFlag(frame)) { - timeLastTickSentMs = System.currentTimeMillis(); - } - break; - default: - // Ignore unknown frames. Throwing an error will close the socket. - errorConsumer.accept( - new IllegalStateException( - "Client received supported frame on stream 0: " + frame.toString())); - } - } - - private void handleFrame(int streamId, FrameType type, Frame frame) { - Subscriber receiver; - synchronized (this) { - receiver = receivers.get(streamId); - } - if (receiver == null) { - handleMissingResponseProcessor(streamId, type, frame); - } else { - switch (type) { - case ERROR: - receiver.onError(Exceptions.from(frame)); - removeReceiver(streamId); - break; - case NEXT_COMPLETE: - receiver.onNext(new PayloadImpl(frame)); - receiver.onComplete(); - break; - case CANCEL: - { - LimitableRequestPublisher sender; - synchronized (this) { - sender = senders.remove(streamId); - removeReceiver(streamId); - } - if (sender != null) { - sender.cancel(); - } - break; - } - case NEXT: - receiver.onNext(new PayloadImpl(frame)); - break; - case REQUEST_N: - { - LimitableRequestPublisher sender; - synchronized (this) { - sender = senders.get(streamId); - } - if (sender != null) { - int n = Frame.RequestN.requestN(frame); - sender.increaseRequestLimit(n); - } - break; - } - case COMPLETE: - receiver.onComplete(); - synchronized (this) { - receivers.remove(streamId); - } - break; - default: - throw new IllegalStateException( - "Client received supported frame on stream " + streamId + ": " + frame.toString()); - } - } - } - - private void handleMissingResponseProcessor(int streamId, FrameType type, Frame frame) { - if (!streamIdSupplier.isBeforeOrCurrent(streamId)) { - if (type == FrameType.ERROR) { - // message for stream that has never existed, we have a problem with - // the overall connection and must tear down - String errorMessage = frame.getDataUtf8(); - - throw new IllegalStateException( - "Client received error for non-existent stream: " - + streamId - + " Message: " - + errorMessage); - } else { - throw new IllegalStateException( - "Client received message for non-existent stream: " - + streamId - + ", frame type: " - + type); - } - } - // receiving a frame after a given stream has been cancelled/completed, - // so ignore (cancellation is async so there is a race condition) - } - - private synchronized void removeReceiver(int streamId) { - receivers.remove(streamId); - } - - private synchronized void removeSender(int streamId) { - senders.remove(streamId); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketErrorException.java b/rsocket-core/src/main/java/io/rsocket/RSocketErrorException.java new file mode 100644 index 000000000..b43b14bae --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/RSocketErrorException.java @@ -0,0 +1,82 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket; + +import reactor.util.annotation.Nullable; + +/** + * Exception that represents an RSocket protocol error. + * + * @see ERROR + * Frame (0x0B) + */ +public class RSocketErrorException extends RuntimeException { + + private static final long serialVersionUID = -1628781753426267554L; + + private static final int MIN_ERROR_CODE = 0x00000001; + + private static final int MAX_ERROR_CODE = 0xFFFFFFFE; + + private final int errorCode; + + /** + * Constructor with a protocol error code and a message. + * + * @param errorCode the RSocket protocol error code + * @param message error explanation + */ + public RSocketErrorException(int errorCode, String message) { + this(errorCode, message, null); + } + + /** + * Alternative to {@link #RSocketErrorException(int, String)} with a root cause. + * + * @param errorCode the RSocket protocol error code + * @param message error explanation + * @param cause a root cause for the error + */ + public RSocketErrorException(int errorCode, String message, @Nullable Throwable cause) { + super(message, cause); + this.errorCode = errorCode; + if (errorCode > MAX_ERROR_CODE && errorCode < MIN_ERROR_CODE) { + throw new IllegalArgumentException( + "Allowed errorCode value should be in range [0x00000001-0xFFFFFFFE]", this); + } + } + + /** + * Return the RSocket error code + * represented by this exception + * + * @return the RSocket protocol error code + */ + public int errorCode() { + return errorCode; + } + + @Override + public String toString() { + return getClass().getSimpleName() + + " (0x" + + Integer.toHexString(errorCode) + + "): " + + getMessage(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java b/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java deleted file mode 100644 index 21553f228..000000000 --- a/rsocket-core/src/main/java/io/rsocket/RSocketFactory.java +++ /dev/null @@ -1,406 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import io.rsocket.exceptions.InvalidSetupException; -import io.rsocket.fragmentation.FragmentationDuplexConnection; -import io.rsocket.frame.SetupFrameFlyweight; -import io.rsocket.frame.VersionFlyweight; -import io.rsocket.internal.ClientServerInputMultiplexer; -import io.rsocket.plugins.DuplexConnectionInterceptor; -import io.rsocket.plugins.PluginRegistry; -import io.rsocket.plugins.Plugins; -import io.rsocket.plugins.RSocketInterceptor; -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import io.rsocket.util.PayloadImpl; -import java.time.Duration; -import java.util.function.Consumer; -import java.util.function.Function; -import java.util.function.Supplier; -import reactor.core.publisher.Mono; - -/** Factory for creating RSocket clients and servers. */ -public class RSocketFactory { - /** - * Creates a factory that establishes client connections to other RSockets. - * - * @return a client factory - */ - public static ClientRSocketFactory connect() { - return new ClientRSocketFactory(); - } - - /** - * Creates a factory that receives server connections from client RSockets. - * - * @return a server factory. - */ - public static ServerRSocketFactory receive() { - return new ServerRSocketFactory(); - } - - public interface Start { - Mono start(); - } - - public interface SetupPayload { - T setupPayload(Payload payload); - } - - public interface Acceptor { - T acceptor(Supplier acceptor); - - default T acceptor(A acceptor) { - return acceptor(() -> acceptor); - } - } - - public interface ClientTransportAcceptor { - Start transport(Supplier transport); - - default Start transport(ClientTransport transport) { - return transport(() -> transport); - } - } - - public interface ServerTransportAcceptor { - Start transport(Supplier> transport); - - default Start transport(ServerTransport transport) { - return transport(() -> transport); - } - } - - public interface Fragmentation { - T fragment(int mtu); - } - - public interface ErrorConsumer { - T errorConsumer(Consumer errorConsumer); - } - - public interface KeepAlive { - T keepAlive(); - - T keepAlive(Duration tickPeriod, Duration ackTimeout, int missedAcks); - - T keepAliveTickPeriod(Duration tickPeriod); - - T keepAliveAckTimeout(Duration ackTimeout); - - T keepAliveMissedAcks(int missedAcks); - } - - public interface MimeType { - T mimeType(String metadataMimeType, String dataMimeType); - - T dataMimeType(String dataMimeType); - - T metadataMimeType(String metadataMimeType); - } - - public static class ClientRSocketFactory - implements Acceptor>, - ClientTransportAcceptor, - KeepAlive, - MimeType, - Fragmentation, - ErrorConsumer, - SetupPayload { - - private Supplier> acceptor = - () -> rSocket -> new AbstractRSocket() {}; - - private Consumer errorConsumer = Throwable::printStackTrace; - private int mtu = 0; - private PluginRegistry plugins = new PluginRegistry(Plugins.defaultPlugins()); - private int flags = SetupFrameFlyweight.FLAGS_STRICT_INTERPRETATION; - - private Payload setupPayload = PayloadImpl.EMPTY; - - private Duration tickPeriod = Duration.ZERO; - private Duration ackTimeout = Duration.ofSeconds(30); - private int missedAcks = 3; - - private String metadataMimeType = "application/binary"; - private String dataMimeType = "application/binary"; - - public ClientRSocketFactory addConnectionPlugin(DuplexConnectionInterceptor interceptor) { - plugins.addConnectionPlugin(interceptor); - return this; - } - - public ClientRSocketFactory addClientPlugin(RSocketInterceptor interceptor) { - plugins.addClientPlugin(interceptor); - return this; - } - - public ClientRSocketFactory addServerPlugin(RSocketInterceptor interceptor) { - plugins.addServerPlugin(interceptor); - return this; - } - - @Override - public ClientRSocketFactory keepAlive() { - tickPeriod = Duration.ofSeconds(20); - return this; - } - - @Override - public ClientRSocketFactory keepAlive( - Duration tickPeriod, Duration ackTimeout, int missedAcks) { - this.tickPeriod = tickPeriod; - this.ackTimeout = ackTimeout; - this.missedAcks = missedAcks; - return this; - } - - @Override - public ClientRSocketFactory keepAliveTickPeriod(Duration tickPeriod) { - this.tickPeriod = tickPeriod; - return this; - } - - @Override - public ClientRSocketFactory keepAliveAckTimeout(Duration ackTimeout) { - this.ackTimeout = ackTimeout; - return this; - } - - @Override - public ClientRSocketFactory keepAliveMissedAcks(int missedAcks) { - this.missedAcks = missedAcks; - return this; - } - - @Override - public ClientRSocketFactory mimeType(String metadataMimeType, String dataMimeType) { - this.dataMimeType = dataMimeType; - this.metadataMimeType = metadataMimeType; - return this; - } - - @Override - public ClientRSocketFactory dataMimeType(String dataMimeType) { - this.dataMimeType = dataMimeType; - return this; - } - - @Override - public ClientRSocketFactory metadataMimeType(String metadataMimeType) { - this.metadataMimeType = metadataMimeType; - return this; - } - - @Override - public Start transport(Supplier transportClient) { - return new StartClient(transportClient); - } - - @Override - public ClientTransportAcceptor acceptor(Supplier> acceptor) { - this.acceptor = acceptor; - return StartClient::new; - } - - @Override - public ClientRSocketFactory fragment(int mtu) { - this.mtu = mtu; - return this; - } - - @Override - public ClientRSocketFactory errorConsumer(Consumer errorConsumer) { - this.errorConsumer = errorConsumer; - return this; - } - - @Override - public ClientRSocketFactory setupPayload(Payload payload) { - this.setupPayload = payload; - return this; - } - - protected class StartClient implements Start { - private final Supplier transportClient; - - StartClient(Supplier transportClient) { - this.transportClient = transportClient; - } - - @Override - public Mono start() { - return transportClient - .get() - .connect() - .flatMap( - connection -> { - Frame setupFrame = - Frame.Setup.from( - flags, - (int) ackTimeout.toMillis(), - (int) ackTimeout.toMillis() * missedAcks, - metadataMimeType, - dataMimeType, - setupPayload); - - if (mtu > 0) { - connection = new FragmentationDuplexConnection(connection, mtu); - } - - ClientServerInputMultiplexer multiplexer = - new ClientServerInputMultiplexer(connection, plugins); - - RSocketClient rSocketClient = - new RSocketClient( - multiplexer.asClientConnection(), - errorConsumer, - StreamIdSupplier.clientSupplier(), - tickPeriod, - ackTimeout, - missedAcks); - - Mono wrappedRSocketClient = - Mono.just(rSocketClient).map(plugins::applyClient); - - DuplexConnection finalConnection = connection; - return wrappedRSocketClient.flatMap( - wrappedClientRSocket -> { - RSocket unwrappedServerSocket = acceptor.get().apply(wrappedClientRSocket); - - Mono wrappedRSocketServer = - Mono.just(unwrappedServerSocket).map(plugins::applyServer); - - return wrappedRSocketServer - .doOnNext( - rSocket -> - new RSocketServer( - multiplexer.asServerConnection(), rSocket, errorConsumer)) - .then(finalConnection.sendOne(setupFrame)) - .then(wrappedRSocketClient); - }); - }); - } - } - } - - public static class ServerRSocketFactory - implements Acceptor, - Fragmentation, - ErrorConsumer { - - private Supplier acceptor; - private Consumer errorConsumer = Throwable::printStackTrace; - private int mtu = 0; - private PluginRegistry plugins = new PluginRegistry(Plugins.defaultPlugins()); - - private ServerRSocketFactory() {} - - public ServerRSocketFactory addConnectionPlugin(DuplexConnectionInterceptor interceptor) { - plugins.addConnectionPlugin(interceptor); - return this; - } - - public ServerRSocketFactory addClientPlugin(RSocketInterceptor interceptor) { - plugins.addClientPlugin(interceptor); - return this; - } - - public ServerRSocketFactory addServerPlugin(RSocketInterceptor interceptor) { - plugins.addServerPlugin(interceptor); - return this; - } - - @Override - public ServerTransportAcceptor acceptor(Supplier acceptor) { - this.acceptor = acceptor; - return ServerStart::new; - } - - @Override - public ServerRSocketFactory fragment(int mtu) { - this.mtu = mtu; - return this; - } - - @Override - public ServerRSocketFactory errorConsumer(Consumer errorConsumer) { - this.errorConsumer = errorConsumer; - return this; - } - - private class ServerStart implements Start { - private final Supplier> transportServer; - - ServerStart(Supplier> transportServer) { - this.transportServer = transportServer; - } - - @Override - public Mono start() { - return transportServer - .get() - .start( - connection -> { - if (mtu > 0) { - connection = new FragmentationDuplexConnection(connection, mtu); - } - - ClientServerInputMultiplexer multiplexer = - new ClientServerInputMultiplexer(connection, plugins); - - return multiplexer - .asStreamZeroConnection() - .receive() - .next() - .flatMap(setupFrame -> processSetupFrame(multiplexer, setupFrame)); - }); - } - - private Mono processSetupFrame( - ClientServerInputMultiplexer multiplexer, Frame setupFrame) { - int version = Frame.Setup.version(setupFrame); - if (version != SetupFrameFlyweight.CURRENT_VERSION) { - InvalidSetupException error = - new InvalidSetupException( - "Unsupported version " + VersionFlyweight.toString(version)); - return multiplexer - .asStreamZeroConnection() - .sendOne(Frame.Error.from(0, error)) - .then(multiplexer.close()); - } - - ConnectionSetupPayload setupPayload = ConnectionSetupPayload.create(setupFrame); - - RSocketClient rSocketClient = - new RSocketClient( - multiplexer.asServerConnection(), errorConsumer, StreamIdSupplier.serverSupplier()); - - Mono wrappedRSocketClient = Mono.just(rSocketClient).map(plugins::applyClient); - - return wrappedRSocketClient - .flatMap( - sender -> acceptor.get().accept(setupPayload, sender).map(plugins::applyServer)) - .map( - handler -> - new RSocketServer(multiplexer.asClientConnection(), handler, errorConsumer)) - .then(); - } - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/RSocketServer.java deleted file mode 100644 index d6138aace..000000000 --- a/rsocket-core/src/main/java/io/rsocket/RSocketServer.java +++ /dev/null @@ -1,423 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import static io.rsocket.Frame.Request.initialRequestN; -import static io.rsocket.frame.FrameHeaderFlyweight.FLAGS_C; -import static io.rsocket.frame.FrameHeaderFlyweight.FLAGS_M; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.util.collection.IntObjectHashMap; -import io.rsocket.exceptions.ApplicationException; -import io.rsocket.internal.LimitableRequestPublisher; -import io.rsocket.util.PayloadImpl; -import java.util.Collection; -import java.util.function.Consumer; -import javax.annotation.Nullable; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import reactor.core.Disposable; -import reactor.core.publisher.*; - -/** Server side RSocket. Receives {@link Frame}s from a {@link RSocketClient} */ -class RSocketServer implements RSocket { - - private final DuplexConnection connection; - private final RSocket requestHandler; - private final Consumer errorConsumer; - - private final IntObjectHashMap sendingSubscriptions; - private final IntObjectHashMap> channelProcessors; - - private final EmitterProcessor sendProcessor; - private Disposable receiveDisposable; - - RSocketServer( - DuplexConnection connection, RSocket requestHandler, Consumer errorConsumer) { - this.connection = connection; - this.requestHandler = requestHandler; - this.errorConsumer = errorConsumer; - this.sendingSubscriptions = new IntObjectHashMap<>(); - this.channelProcessors = new IntObjectHashMap<>(); - - // DO NOT Change the order here. The Send processor must be subscribed to before receiving - // connections - this.sendProcessor = EmitterProcessor.create(); - - connection - .send(sendProcessor) - .doOnError(this::handleSendProcessorError) - .doFinally(this::handleSendProcessorCancel) - .subscribe(); - - this.receiveDisposable = - connection.receive().flatMap(this::handleFrame).doOnError(errorConsumer).then().subscribe(); - - this.connection - .onClose() - .doOnError(errorConsumer) - .doFinally( - s -> { - cleanup(); - receiveDisposable.dispose(); - }) - .subscribe(); - } - - private void handleSendProcessorError(Throwable t) { - Collection values; - Collection> values1; - synchronized (RSocketServer.this) { - values = sendingSubscriptions.values(); - values1 = channelProcessors.values(); - } - - for (Subscription subscription : values) { - try { - subscription.cancel(); - } catch (Throwable e) { - errorConsumer.accept(e); - } - } - - for (UnicastProcessor subscription : values1) { - try { - subscription.cancel(); - } catch (Throwable e) { - errorConsumer.accept(e); - } - } - } - - private void handleSendProcessorCancel(SignalType t) { - if (SignalType.ON_ERROR == t) { - return; - } - Collection values; - Collection> values1; - synchronized (RSocketServer.this) { - values = sendingSubscriptions.values(); - values1 = channelProcessors.values(); - } - - for (Subscription subscription : values) { - try { - subscription.cancel(); - } catch (Throwable e) { - errorConsumer.accept(e); - } - } - - for (UnicastProcessor subscription : values1) { - try { - subscription.cancel(); - } catch (Throwable e) { - errorConsumer.accept(e); - } - } - } - - @Override - public Mono fireAndForget(Payload payload) { - try { - return requestHandler.fireAndForget(payload); - } catch (Throwable t) { - return Mono.error(t); - } - } - - @Override - public Mono requestResponse(Payload payload) { - try { - return requestHandler.requestResponse(payload); - } catch (Throwable t) { - return Mono.error(t); - } - } - - @Override - public Flux requestStream(Payload payload) { - try { - return requestHandler.requestStream(payload); - } catch (Throwable t) { - return Flux.error(t); - } - } - - @Override - public Flux requestChannel(Publisher payloads) { - try { - return requestHandler.requestChannel(payloads); - } catch (Throwable t) { - return Flux.error(t); - } - } - - @Override - public Mono metadataPush(Payload payload) { - try { - return requestHandler.metadataPush(payload); - } catch (Throwable t) { - return Mono.error(t); - } - } - - @Override - public Mono close() { - return connection.close(); - } - - @Override - public Mono onClose() { - return connection.onClose(); - } - - private void cleanup() { - cleanUpSendingSubscriptions(); - cleanUpChannelProcessors(); - - requestHandler.close().subscribe(); - } - - private synchronized void cleanUpSendingSubscriptions() { - sendingSubscriptions.values().forEach(Subscription::cancel); - sendingSubscriptions.clear(); - } - - private synchronized void cleanUpChannelProcessors() { - channelProcessors.values().forEach(Subscription::cancel); - channelProcessors.clear(); - } - - private Mono handleFrame(Frame frame) { - try { - int streamId = frame.getStreamId(); - Subscriber receiver; - switch (frame.getType()) { - case FIRE_AND_FORGET: - return handleFireAndForget(streamId, fireAndForget(new PayloadImpl(frame))); - case REQUEST_RESPONSE: - return handleRequestResponse(streamId, requestResponse(new PayloadImpl(frame))); - case CANCEL: - return handleCancelFrame(streamId); - case KEEPALIVE: - return handleKeepAliveFrame(frame); - case REQUEST_N: - return handleRequestN(streamId, frame); - case REQUEST_STREAM: - return handleStream( - streamId, requestStream(new PayloadImpl(frame)), initialRequestN(frame)); - case REQUEST_CHANNEL: - return handleChannel(streamId, frame); - case PAYLOAD: - // TODO: Hook in receiving socket. - return Mono.empty(); - case METADATA_PUSH: - return metadataPush(new PayloadImpl(frame)); - case LEASE: - // Lease must not be received here as this is the server end of the socket which sends - // leases. - return Mono.empty(); - case NEXT: - receiver = getChannelProcessor(streamId); - if (receiver != null) { - receiver.onNext(new PayloadImpl(frame)); - } - return Mono.empty(); - case COMPLETE: - receiver = getChannelProcessor(streamId); - if (receiver != null) { - receiver.onComplete(); - } - return Mono.empty(); - case ERROR: - receiver = getChannelProcessor(streamId); - if (receiver != null) { - receiver.onError(new ApplicationException(Frame.Error.message(frame))); - } - return Mono.empty(); - case NEXT_COMPLETE: - receiver = getChannelProcessor(streamId); - if (receiver != null) { - receiver.onNext(new PayloadImpl(frame)); - receiver.onComplete(); - } - - return Mono.empty(); - - case SETUP: - return handleError( - streamId, new IllegalStateException("Setup frame received post setup.")); - default: - return handleError( - streamId, - new IllegalStateException( - "ServerRSocket: Unexpected frame type: " + frame.getType())); - } - } finally { - frame.release(); - } - } - - private Mono handleFireAndForget(int streamId, Mono result) { - return result - .doOnSubscribe(subscription -> addSubscription(streamId, subscription)) - .doOnError(errorConsumer) - .doFinally(signalType -> removeSubscription(streamId)) - .ignoreElement(); - } - - private Mono handleRequestResponse(int streamId, Mono response) { - return response - .doOnSubscribe(subscription -> addSubscription(streamId, subscription)) - .map( - payload -> { - int flags = FLAGS_C; - if (payload.hasMetadata()) { - flags = Frame.setFlag(flags, FLAGS_M); - } - return Frame.PayloadFrame.from(streamId, FrameType.NEXT_COMPLETE, payload, flags); - }) - .doOnError(errorConsumer) - .onErrorResume(t -> Mono.just(Frame.Error.from(streamId, t))) - .doOnNext(sendProcessor::onNext) - .doFinally(signalType -> removeSubscription(streamId)) - .then(); - } - - private Mono handleStream(int streamId, Flux response, int initialRequestN) { - response - .map(payload -> Frame.PayloadFrame.from(streamId, FrameType.NEXT, payload)) - .transform( - frameFlux -> { - LimitableRequestPublisher frames = LimitableRequestPublisher.wrap(frameFlux); - synchronized (RSocketServer.this) { - sendingSubscriptions.put(streamId, frames); - } - frames.increaseRequestLimit(initialRequestN); - return frames; - }) - .concatWith(Mono.just(Frame.PayloadFrame.from(streamId, FrameType.COMPLETE))) - .onErrorResume(t -> Mono.just(Frame.Error.from(streamId, t))) - .doOnNext(sendProcessor::onNext) - .doFinally(signalType -> removeSubscription(streamId)) - .subscribe(); - - return Mono.empty(); - } - - private Mono handleChannel(int streamId, Frame firstFrame) { - UnicastProcessor frames = UnicastProcessor.create(); - addChannelProcessor(streamId, frames); - - Flux payloads = - frames - .doOnCancel( - () -> { - if (connection.availability() > 0.0) { - sendProcessor.onNext(Frame.Cancel.from(streamId)); - } - }) - .doOnError( - t -> { - if (connection.availability() > 0.0) { - sendProcessor.onNext(Frame.Error.from(streamId, t)); - } - }) - .doOnRequest( - l -> { - if (connection.availability() > 0.0) { - sendProcessor.onNext(Frame.RequestN.from(streamId, l)); - } - }) - .doFinally(signalType -> removeChannelProcessor(streamId)); - - // not chained, as the payload should be enqueued in the Unicast processor before this method - // returns - // and any later payload can be processed - frames.onNext(new PayloadImpl(firstFrame)); - - return handleStream(streamId, requestChannel(payloads), initialRequestN(firstFrame)); - } - - private Mono handleKeepAliveFrame(Frame frame) { - return Mono.fromRunnable( - () -> { - if (Frame.Keepalive.hasRespondFlag(frame)) { - ByteBuf data = Unpooled.wrappedBuffer(frame.getData()); - sendProcessor.onNext(Frame.Keepalive.from(data, false)); - } - }); - } - - private Mono handleCancelFrame(int streamId) { - return Mono.fromRunnable( - () -> { - Subscription subscription; - synchronized (this) { - subscription = sendingSubscriptions.remove(streamId); - } - - if (subscription != null) { - subscription.cancel(); - } - }); - } - - private Mono handleError(int streamId, Throwable t) { - return Mono.fromRunnable( - () -> { - errorConsumer.accept(t); - sendProcessor.onNext(Frame.Error.from(streamId, t)); - }); - } - - private Mono handleRequestN(int streamId, Frame frame) { - final Subscription subscription = getSubscription(streamId); - if (subscription != null) { - int n = Frame.RequestN.requestN(frame); - subscription.request(n >= Integer.MAX_VALUE ? Long.MAX_VALUE : n); - } - return Mono.empty(); - } - - private synchronized void addSubscription(int streamId, Subscription subscription) { - sendingSubscriptions.put(streamId, subscription); - } - - private synchronized @Nullable Subscription getSubscription(int streamId) { - return sendingSubscriptions.get(streamId); - } - - private synchronized void removeSubscription(int streamId) { - sendingSubscriptions.remove(streamId); - } - - private synchronized void addChannelProcessor(int streamId, UnicastProcessor processor) { - channelProcessors.put(streamId, processor); - } - - private synchronized @Nullable UnicastProcessor getChannelProcessor(int streamId) { - return channelProcessors.get(streamId); - } - - private synchronized void removeChannelProcessor(int streamId) { - channelProcessors.remove(streamId); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/SocketAcceptor.java b/rsocket-core/src/main/java/io/rsocket/SocketAcceptor.java index b54ff8601..a42626e78 100644 --- a/rsocket-core/src/main/java/io/rsocket/SocketAcceptor.java +++ b/rsocket-core/src/main/java/io/rsocket/SocketAcceptor.java @@ -1,40 +1,93 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package io.rsocket; import io.rsocket.exceptions.SetupException; +import java.util.function.Function; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; /** - * {@code RSocket} is a full duplex protocol where a client and server are identical in terms of - * both having the capability to initiate requests to their peer. This interface provides the - * contract where a server accepts a new {@code RSocket} for sending requests to the peer and - * returns a new {@code RSocket} that will be used to accept requests from it's peer. + * RSocket is a full duplex protocol where a client and server are identical in terms of both having + * the capability to initiate requests to their peer. This interface provides the contract where a + * client or server handles the {@code setup} for a new connection and creates a responder {@code + * RSocket} for accepting requests from the remote peer. */ public interface SocketAcceptor { /** - * Accepts a new {@code RSocket} used to send requests to the peer and returns another {@code - * RSocket} that is used for accepting requests from the peer. + * Handle the {@code SETUP} frame for a new connection and create a responder {@code RSocket} for + * handling requests from the remote peer. * - * @param setup Setup as sent by the client. - * @param sendingSocket Socket used to send requests to the peer. - * @return Socket to accept requests from the peer. + * @param setup the {@code setup} received from a client in a server scenario, or in a client + * scenario this is the setup about to be sent to the server. + * @param sendingSocket socket for sending requests to the remote peer. + * @return {@code RSocket} to accept requests with. * @throws SetupException If the acceptor needs to reject the setup of this socket. */ Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket); + + /** Create a {@code SocketAcceptor} that handles requests with the given {@code RSocket}. */ + static SocketAcceptor with(RSocket rsocket) { + return (setup, sendingRSocket) -> Mono.just(rsocket); + } + + /** Create a {@code SocketAcceptor} for fire-and-forget interactions with the given handler. */ + static SocketAcceptor forFireAndForget(Function> handler) { + return with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return handler.apply(payload); + } + }); + } + + /** Create a {@code SocketAcceptor} for request-response interactions with the given handler. */ + static SocketAcceptor forRequestResponse(Function> handler) { + return with( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return handler.apply(payload); + } + }); + } + + /** Create a {@code SocketAcceptor} for request-stream interactions with the given handler. */ + static SocketAcceptor forRequestStream(Function> handler) { + return with( + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + return handler.apply(payload); + } + }); + } + + /** Create a {@code SocketAcceptor} for request-channel interactions with the given handler. */ + static SocketAcceptor forRequestChannel(Function, Flux> handler) { + return with( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + return handler.apply(payloads); + } + }); + } } diff --git a/rsocket-core/src/main/java/io/rsocket/StreamIdSupplier.java b/rsocket-core/src/main/java/io/rsocket/StreamIdSupplier.java deleted file mode 100644 index 07607ce6e..000000000 --- a/rsocket-core/src/main/java/io/rsocket/StreamIdSupplier.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -final class StreamIdSupplier { - - private int streamId; - - private StreamIdSupplier(int streamId) { - this.streamId = streamId; - } - - synchronized int nextStreamId() { - streamId += 2; - return streamId; - } - - synchronized boolean isBeforeOrCurrent(int streamId) { - return this.streamId >= streamId && streamId > 0; - } - - static StreamIdSupplier clientSupplier() { - return new StreamIdSupplier(-1); - } - - static StreamIdSupplier serverSupplier() { - return new StreamIdSupplier(0); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ClientServerInputMultiplexer.java b/rsocket-core/src/main/java/io/rsocket/core/ClientServerInputMultiplexer.java new file mode 100644 index 000000000..e19d31924 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ClientServerInputMultiplexer.java @@ -0,0 +1,348 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Closeable; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.plugins.DuplexConnectionInterceptor.Type; +import io.rsocket.plugins.InitializingInterceptorRegistry; +import java.net.SocketAddress; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; + +/** + * {@link DuplexConnection#receive()} is a single stream on which the following type of frames + * arrive: + * + *

+ * + *

The only way to differentiate these two frames is determining whether the stream Id is odd or + * even. Even IDs are for the streams initiated by server and odds are for streams initiated by the + * client. + */ +class ClientServerInputMultiplexer implements CoreSubscriber, Closeable { + + private final InternalDuplexConnection serverReceiver; + private final InternalDuplexConnection clientReceiver; + private final DuplexConnection serverConnection; + private final DuplexConnection clientConnection; + private final DuplexConnection source; + private final boolean isClient; + + private Subscription s; + + private Throwable t; + + private volatile int state; + private static final AtomicIntegerFieldUpdater STATE = + AtomicIntegerFieldUpdater.newUpdater(ClientServerInputMultiplexer.class, "state"); + + public ClientServerInputMultiplexer( + DuplexConnection source, InitializingInterceptorRegistry registry, boolean isClient) { + this.source = source; + this.isClient = isClient; + + this.serverReceiver = new InternalDuplexConnection(Type.SERVER, this, source); + this.clientReceiver = new InternalDuplexConnection(Type.CLIENT, this, source); + this.serverConnection = registry.initConnection(Type.SERVER, serverReceiver); + this.clientConnection = registry.initConnection(Type.CLIENT, clientReceiver); + } + + DuplexConnection asServerConnection() { + return serverConnection; + } + + DuplexConnection asClientConnection() { + return clientConnection; + } + + @Override + public void dispose() { + source.dispose(); + } + + @Override + public boolean isDisposed() { + return source.isDisposed(); + } + + @Override + public Mono onClose() { + return source.onClose(); + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + s.request(Long.MAX_VALUE); + } + } + + @Override + public void onNext(ByteBuf frame) { + int streamId = FrameHeaderCodec.streamId(frame); + final Type type; + if (streamId == 0) { + switch (FrameHeaderCodec.frameType(frame)) { + case LEASE: + case KEEPALIVE: + case ERROR: + type = isClient ? Type.CLIENT : Type.SERVER; + break; + default: + type = isClient ? Type.SERVER : Type.CLIENT; + } + } else if ((streamId & 0b1) == 0) { + type = Type.SERVER; + } else { + type = Type.CLIENT; + } + + switch (type) { + case CLIENT: + clientReceiver.onNext(frame); + break; + case SERVER: + serverReceiver.onNext(frame); + break; + } + } + + @Override + public void onComplete() { + final int previousState = STATE.getAndSet(this, Integer.MIN_VALUE); + if (previousState == Integer.MIN_VALUE || previousState == 0) { + return; + } + + if (clientReceiver.isSubscribed()) { + clientReceiver.onComplete(); + } + if (serverReceiver.isSubscribed()) { + serverReceiver.onComplete(); + } + } + + @Override + public void onError(Throwable t) { + this.t = t; + + final int previousState = STATE.getAndSet(this, Integer.MIN_VALUE); + if (previousState == Integer.MIN_VALUE || previousState == 0) { + return; + } + + if (clientReceiver.isSubscribed()) { + clientReceiver.onError(t); + } + if (serverReceiver.isSubscribed()) { + serverReceiver.onError(t); + } + } + + boolean notifyRequested() { + final int currentState = incrementAndGetCheckingState(); + if (currentState == Integer.MIN_VALUE) { + return false; + } + + if (currentState == 2) { + source.receive().subscribe(this); + } + + return true; + } + + int incrementAndGetCheckingState() { + int prev, next; + for (; ; ) { + prev = this.state; + + if (prev == Integer.MIN_VALUE) { + return prev; + } + + next = prev + 1; + if (STATE.compareAndSet(this, prev, next)) { + return next; + } + } + } + + @Override + public String toString() { + return "ClientServerInputMultiplexer{" + + "serverReceiver=" + + serverReceiver + + ", clientReceiver=" + + clientReceiver + + ", serverConnection=" + + serverConnection + + ", clientConnection=" + + clientConnection + + ", source=" + + source + + ", isClient=" + + isClient + + ", s=" + + s + + ", t=" + + t + + ", state=" + + state + + '}'; + } + + private static class InternalDuplexConnection extends Flux + implements Subscription, DuplexConnection { + private final Type type; + private final ClientServerInputMultiplexer clientServerInputMultiplexer; + private final DuplexConnection source; + + private volatile int state; + static final AtomicIntegerFieldUpdater STATE = + AtomicIntegerFieldUpdater.newUpdater(InternalDuplexConnection.class, "state"); + + CoreSubscriber actual; + + public InternalDuplexConnection( + Type type, + ClientServerInputMultiplexer clientServerInputMultiplexer, + DuplexConnection source) { + this.type = type; + this.clientServerInputMultiplexer = clientServerInputMultiplexer; + this.source = source; + } + + @Override + public void subscribe(CoreSubscriber actual) { + if (this.state == 0 && STATE.compareAndSet(this, 0, 1)) { + this.actual = actual; + actual.onSubscribe(this); + } else { + Operators.error( + actual, + new IllegalStateException("InternalDuplexConnection allows only single subscription")); + } + } + + @Override + public void request(long n) { + if (this.state == 1 && STATE.compareAndSet(this, 1, 2)) { + final ClientServerInputMultiplexer multiplexer = clientServerInputMultiplexer; + if (!multiplexer.notifyRequested()) { + final Throwable t = multiplexer.t; + if (t != null) { + this.actual.onError(t); + } else { + this.actual.onComplete(); + } + } + } + } + + @Override + public void cancel() { + // no ops + } + + void onNext(ByteBuf frame) { + this.actual.onNext(frame); + } + + void onComplete() { + this.actual.onComplete(); + } + + void onError(Throwable t) { + this.actual.onError(t); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + source.sendFrame(streamId, frame); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + source.sendErrorAndClose(e); + } + + @Override + public Flux receive() { + return this; + } + + @Override + public ByteBufAllocator alloc() { + return source.alloc(); + } + + @Override + public SocketAddress remoteAddress() { + return source.remoteAddress(); + } + + @Override + public void dispose() { + source.dispose(); + } + + @Override + public boolean isDisposed() { + return source.isDisposed(); + } + + public boolean isSubscribed() { + return this.state != 0; + } + + @Override + public Mono onClose() { + return source.onClose(); + } + + @Override + public double availability() { + return source.availability(); + } + + @Override + public String toString() { + return "InternalDuplexConnection{" + + "type=" + + type + + ", source=" + + source + + ", state=" + + state + + '}'; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ClientSetup.java b/rsocket-core/src/main/java/io/rsocket/core/ClientSetup.java new file mode 100644 index 000000000..3477b8d6d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ClientSetup.java @@ -0,0 +1,49 @@ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.DuplexConnection; +import java.nio.channels.ClosedChannelException; +import reactor.core.Disposable; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +abstract class ClientSetup { + abstract Mono> init(DuplexConnection connection); +} + +class DefaultClientSetup extends ClientSetup { + + @Override + Mono> init(DuplexConnection connection) { + return Mono.create( + sink -> sink.onRequest(__ -> sink.success(Tuples.of(Unpooled.EMPTY_BUFFER, connection)))); + } +} + +class ResumableClientSetup extends ClientSetup { + + @Override + Mono> init(DuplexConnection connection) { + return Mono.create( + sink -> { + sink.onRequest( + __ -> { + new SetupHandlingDuplexConnection(connection, sink); + }); + + Disposable subscribe = + connection + .onClose() + .doFinally(__ -> sink.error(new ClosedChannelException())) + .subscribe(); + sink.onCancel( + () -> { + subscribe.dispose(); + connection.dispose(); + connection.receive().subscribe(); + }); + }); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/DefaultConnectionSetupPayload.java b/rsocket-core/src/main/java/io/rsocket/core/DefaultConnectionSetupPayload.java new file mode 100644 index 000000000..9b5647c6f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/DefaultConnectionSetupPayload.java @@ -0,0 +1,119 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.SetupFrameCodec; + +/** + * Default implementation of {@link ConnectionSetupPayload}. Primarily for internal use within + * RSocket Java but may be created in an application, e.g. for testing purposes. + */ +public class DefaultConnectionSetupPayload extends ConnectionSetupPayload { + + private final ByteBuf setupFrame; + + public DefaultConnectionSetupPayload(ByteBuf setupFrame) { + this.setupFrame = setupFrame; + } + + @Override + public boolean hasMetadata() { + return FrameHeaderCodec.hasMetadata(setupFrame); + } + + @Override + public ByteBuf sliceMetadata() { + final ByteBuf metadata = SetupFrameCodec.metadata(setupFrame); + return metadata == null ? Unpooled.EMPTY_BUFFER : metadata; + } + + @Override + public ByteBuf sliceData() { + return SetupFrameCodec.data(setupFrame); + } + + @Override + public ByteBuf data() { + return sliceData(); + } + + @Override + public ByteBuf metadata() { + return sliceMetadata(); + } + + @Override + public String metadataMimeType() { + return SetupFrameCodec.metadataMimeType(setupFrame); + } + + @Override + public String dataMimeType() { + return SetupFrameCodec.dataMimeType(setupFrame); + } + + @Override + public int keepAliveInterval() { + return SetupFrameCodec.keepAliveInterval(setupFrame); + } + + @Override + public int keepAliveMaxLifetime() { + return SetupFrameCodec.keepAliveMaxLifetime(setupFrame); + } + + @Override + public int getFlags() { + return FrameHeaderCodec.flags(setupFrame); + } + + @Override + public boolean willClientHonorLease() { + return SetupFrameCodec.honorLease(setupFrame); + } + + @Override + public boolean isResumeEnabled() { + return SetupFrameCodec.resumeEnabled(setupFrame); + } + + @Override + public ByteBuf resumeToken() { + return SetupFrameCodec.resumeToken(setupFrame); + } + + @Override + public ConnectionSetupPayload touch() { + setupFrame.touch(); + return this; + } + + @Override + public ConnectionSetupPayload touch(Object hint) { + setupFrame.touch(hint); + return this; + } + + @Override + protected void deallocate() { + setupFrame.release(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/DefaultRSocketClient.java b/rsocket-core/src/main/java/io/rsocket/core/DefaultRSocketClient.java new file mode 100644 index 000000000..82a02268d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/DefaultRSocketClient.java @@ -0,0 +1,562 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCounted; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import java.util.AbstractMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Consumer; +import java.util.stream.Stream; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.CorePublisher; +import reactor.core.CoreSubscriber; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoOperator; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +/** + * Default implementation of {@link RSocketClient} + * + * @since 1.0.1 + */ +class DefaultRSocketClient extends ResolvingOperator + implements CoreSubscriber, CorePublisher, RSocketClient { + static final Consumer DISCARD_ELEMENTS_CONSUMER = + data -> { + if (data instanceof ReferenceCounted) { + ReferenceCounted referenceCounted = ((ReferenceCounted) data); + if (referenceCounted.refCnt() > 0) { + try { + referenceCounted.release(); + } catch (IllegalReferenceCountException e) { + // ignored + } + } + } + }; + + static final Object ON_DISCARD_KEY; + + static { + Context discardAwareContext = Operators.enableOnDiscard(null, DISCARD_ELEMENTS_CONSUMER); + ON_DISCARD_KEY = discardAwareContext.stream().findFirst().get().getKey(); + } + + final Mono source; + + final Sinks.Empty onDisposeSink; + + volatile Subscription s; + + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(DefaultRSocketClient.class, Subscription.class, "s"); + + DefaultRSocketClient(Mono source) { + this.source = unwrapReconnectMono(source); + this.onDisposeSink = Sinks.empty(); + } + + private Mono unwrapReconnectMono(Mono source) { + return source instanceof ReconnectMono ? ((ReconnectMono) source).getSource() : source; + } + + @Override + public Mono onClose() { + return this.onDisposeSink.asMono(); + } + + @Override + public Mono source() { + return Mono.fromDirect(this); + } + + @Override + public Mono fireAndForget(Mono payloadMono) { + return new RSocketClientMonoOperator<>(this, FrameType.REQUEST_FNF, payloadMono); + } + + @Override + public Mono requestResponse(Mono payloadMono) { + return new RSocketClientMonoOperator<>(this, FrameType.REQUEST_RESPONSE, payloadMono); + } + + @Override + public Flux requestStream(Mono payloadMono) { + return new RSocketClientFluxOperator<>(this, FrameType.REQUEST_STREAM, payloadMono); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return new RSocketClientFluxOperator<>(this, FrameType.REQUEST_CHANNEL, payloads); + } + + @Override + public Mono metadataPush(Mono payloadMono) { + return new RSocketClientMonoOperator<>(this, FrameType.METADATA_PUSH, payloadMono); + } + + @Override + @SuppressWarnings("uncheked") + public void subscribe(CoreSubscriber actual) { + final ResolvingOperator.MonoDeferredResolutionOperator inner = + new ResolvingOperator.MonoDeferredResolutionOperator<>(this, actual); + actual.onSubscribe(inner); + + this.observe(inner); + } + + @Override + public void subscribe(Subscriber s) { + subscribe(Operators.toCoreSubscriber(s)); + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } + } + + @Override + public void onComplete() { + final Subscription s = this.s; + final RSocket value = this.value; + + if (s == Operators.cancelledSubscription() || !S.compareAndSet(this, s, null)) { + this.doFinally(); + return; + } + + if (value == null) { + this.terminate(new IllegalStateException("Source completed empty")); + } else { + this.complete(value); + } + } + + @Override + public void onError(Throwable t) { + final Subscription s = this.s; + + if (s == Operators.cancelledSubscription() + || S.getAndSet(this, Operators.cancelledSubscription()) + == Operators.cancelledSubscription()) { + this.doFinally(); + Operators.onErrorDropped(t, Context.empty()); + return; + } + + this.doFinally(); + // terminate upstream which means retryBackoff has exhausted + this.terminate(t); + } + + @Override + public void onNext(RSocket value) { + if (this.s == Operators.cancelledSubscription()) { + this.doOnValueExpired(value); + return; + } + + this.value = value; + // volatile write and check on racing + this.doFinally(); + } + + @Override + protected void doSubscribe() { + this.source.subscribe(this); + } + + @Override + protected void doOnValueResolved(RSocket value) { + value.onClose().subscribe(null, t -> this.invalidate(), this::invalidate); + } + + @Override + protected void doOnValueExpired(RSocket value) { + value.dispose(); + } + + @Override + protected void doOnDispose() { + Operators.terminate(S, this); + final RSocket value = this.value; + if (value != null) { + value.onClose().subscribe(null, onDisposeSink::tryEmitError, onDisposeSink::tryEmitEmpty); + } else { + onDisposeSink.tryEmitEmpty(); + } + } + + static final class FlatMapMain implements CoreSubscriber, Context, Scannable { + + final DefaultRSocketClient parent; + final CoreSubscriber actual; + + final FlattingInner second; + + Subscription s; + + boolean done; + + FlatMapMain( + DefaultRSocketClient parent, CoreSubscriber actual, FrameType requestType) { + this.parent = parent; + this.actual = actual; + this.second = new FlattingInner<>(parent, this, actual, requestType); + } + + @Override + public Context currentContext() { + return this; + } + + @Override + public Stream inners() { + return Stream.of(this.second); + } + + @Override + @Nullable + public Object scanUnsafe(Attr key) { + if (key == Attr.PARENT) return this.s; + if (key == Attr.CANCELLED) return this.second.isCancelled(); + if (key == Attr.TERMINATED) return this.done; + + return null; + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + this.actual.onSubscribe(this.second); + } + } + + @Override + public void onNext(Payload payload) { + if (this.done) { + if (payload.refCnt() > 0) { + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + // ignored + } + } + return; + } + this.done = true; + + final FlattingInner inner = this.second; + + if (inner.isCancelled()) { + if (payload.refCnt() > 0) { + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + // ignored + } + } + return; + } + + inner.payload = payload; + + if (inner.isCancelled()) { + if (FlattingInner.PAYLOAD.compareAndSet(inner, payload, null)) { + if (payload.refCnt() > 0) { + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + // ignored + } + } + } + return; + } + + this.parent.observe(inner); + } + + @Override + public void onError(Throwable t) { + if (this.done) { + Operators.onErrorDropped(t, this.actual.currentContext()); + return; + } + this.done = true; + + this.actual.onError(t); + } + + @Override + public void onComplete() { + if (this.done) { + return; + } + this.done = true; + + this.actual.onComplete(); + } + + void request(long n) { + this.s.request(n); + } + + void cancel() { + this.s.cancel(); + } + + @Override + @SuppressWarnings("unchecked") + public K get(Object key) { + if (key == ON_DISCARD_KEY) { + return (K) DISCARD_ELEMENTS_CONSUMER; + } + return this.actual.currentContext().get(key); + } + + @Override + public boolean hasKey(Object key) { + if (key == ON_DISCARD_KEY) { + return true; + } + return this.actual.currentContext().hasKey(key); + } + + @Override + public Context put(Object key, Object value) { + return this.actual + .currentContext() + .put(ON_DISCARD_KEY, DISCARD_ELEMENTS_CONSUMER) + .put(key, value); + } + + @Override + public Context delete(Object key) { + return this.actual + .currentContext() + .put(ON_DISCARD_KEY, DISCARD_ELEMENTS_CONSUMER) + .delete(key); + } + + @Override + public int size() { + return this.actual.currentContext().size() + 1; + } + + @Override + public Stream> stream() { + return Stream.concat( + Stream.of( + new AbstractMap.SimpleImmutableEntry<>(ON_DISCARD_KEY, DISCARD_ELEMENTS_CONSUMER)), + this.actual.currentContext().stream()); + } + } + + static final class FlattingInner extends DeferredResolution { + + final FlatMapMain main; + final FrameType interactionType; + + volatile Payload payload; + + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater PAYLOAD = + AtomicReferenceFieldUpdater.newUpdater(FlattingInner.class, Payload.class, "payload"); + + FlattingInner( + DefaultRSocketClient parent, + FlatMapMain main, + CoreSubscriber actual, + FrameType interactionType) { + super(parent, actual); + + this.main = main; + this.interactionType = interactionType; + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public void accept(RSocket rSocket, Throwable t) { + if (this.isCancelled()) { + return; + } + + Payload payload = PAYLOAD.getAndSet(this, null); + + // means cancelled + if (payload == null) { + return; + } + + if (t != null) { + if (payload.refCnt() > 0) { + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + // ignored + } + } + onError(t); + return; + } + + CorePublisher source; + switch (this.interactionType) { + case REQUEST_FNF: + source = rSocket.fireAndForget(payload); + break; + case REQUEST_RESPONSE: + source = rSocket.requestResponse(payload); + break; + case REQUEST_STREAM: + source = rSocket.requestStream(payload); + break; + case METADATA_PUSH: + source = rSocket.metadataPush(payload); + break; + default: + this.onError(new IllegalStateException("Should never happen")); + return; + } + + source.subscribe((CoreSubscriber) this); + } + + @Override + public void request(long n) { + super.request(n); + this.main.request(n); + } + + public void cancel() { + long state = REQUESTED.getAndSet(this, STATE_CANCELLED); + if (state == STATE_CANCELLED) { + return; + } + + this.main.cancel(); + + if (state == STATE_SUBSCRIBED) { + this.s.cancel(); + } else { + this.parent.remove(this); + Payload payload = PAYLOAD.getAndSet(this, null); + if (payload != null) { + payload.release(); + } + } + } + } + + static final class RequestChannelInner extends DeferredResolution { + + final FrameType interactionType; + final Publisher upstream; + + RequestChannelInner( + DefaultRSocketClient parent, + Publisher upstream, + CoreSubscriber actual, + FrameType interactionType) { + super(parent, actual); + + this.upstream = upstream; + this.interactionType = interactionType; + } + + @Override + public void accept(RSocket rSocket, Throwable t) { + if (this.isCancelled()) { + return; + } + + if (t != null) { + onError(t); + return; + } + + Flux source; + if (this.interactionType == FrameType.REQUEST_CHANNEL) { + source = rSocket.requestChannel(this.upstream); + } else { + this.onError(new IllegalStateException("Should never happen")); + return; + } + + source.subscribe(this); + } + } + + static class RSocketClientMonoOperator extends MonoOperator { + + final DefaultRSocketClient parent; + final FrameType requestType; + + public RSocketClientMonoOperator( + DefaultRSocketClient parent, FrameType requestType, Mono source) { + super(source); + this.parent = parent; + this.requestType = requestType; + } + + @Override + public void subscribe(CoreSubscriber actual) { + this.source.subscribe(new FlatMapMain(this.parent, actual, this.requestType)); + } + } + + static class RSocketClientFluxOperator> extends Flux { + + final DefaultRSocketClient parent; + final FrameType requestType; + final ST source; + + public RSocketClientFluxOperator( + DefaultRSocketClient parent, FrameType requestType, ST source) { + this.parent = parent; + this.requestType = requestType; + this.source = source; + } + + @Override + public void subscribe(CoreSubscriber actual) { + if (requestType == FrameType.REQUEST_CHANNEL) { + RequestChannelInner inner = + new RequestChannelInner(this.parent, source, actual, requestType); + actual.onSubscribe(inner); + this.parent.observe(inner); + } else { + this.source.subscribe(new FlatMapMain<>(this.parent, actual, this.requestType)); + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java new file mode 100644 index 000000000..a5d527f5c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java @@ -0,0 +1,295 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.*; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.RequestInterceptor; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; + +final class FireAndForgetRequesterMono extends Mono implements Subscription, Scannable { + + volatile long state; + + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(FireAndForgetRequesterMono.class, "state"); + + final Payload payload; + + final ByteBufAllocator allocator; + final int mtu; + final int maxFrameLength; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + + @Nullable final RequestInterceptor requestInterceptor; + + FireAndForgetRequesterMono(Payload payload, RequesterResponderSupport requesterResponderSupport) { + this.allocator = requesterResponderSupport.getAllocator(); + this.payload = payload; + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + final IllegalStateException e = + new IllegalStateException("FireAndForgetMono allows only a single Subscriber"); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + + Operators.error(actual, e); + return; + } + + actual.onSubscribe(this); + + final Payload p = this.payload; + int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, p, false)) { + lazyTerminate(STATE, this); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + actual.onError(e); + return; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + + actual.onError(e); + return; + } + + final int streamId; + try { + streamId = this.requesterResponderSupport.getNextStreamId(); + } catch (Throwable t) { + lazyTerminate(STATE, this); + + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + actual.onError(ut); + return; + } + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onStart(streamId, FrameType.REQUEST_FNF, p.metadata()); + } + + try { + if (isTerminated(this.state)) { + p.release(); + + if (interceptor != null) { + interceptor.onCancel(streamId, FrameType.REQUEST_FNF); + } + + return; + } + + sendReleasingPayload( + streamId, FrameType.REQUEST_FNF, mtu, p, this.connection, this.allocator, true); + } catch (Throwable e) { + lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, e); + } + + actual.onError(e); + return; + } + + lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, null); + } + + actual.onComplete(); + } + + @Override + public void request(long n) { + // no ops + } + + @Override + public void cancel() { + markTerminated(STATE, this); + } + + @Override + @Nullable + public Void block(Duration m) { + return block(); + } + + /** + * This method is deliberately non-blocking regardless it is named as `.block`. The main intent to + * keep this method along with the {@link #subscribe()} is to eliminate redundancy which comes + * with a default block method implementation. + */ + @Override + @Nullable + public Void block() { + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + final IllegalStateException e = + new IllegalStateException("FireAndForgetMono allows only a single Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + throw e; + } + + final Payload p = this.payload; + try { + if (!isValid(this.mtu, this.maxFrameLength, p, false)) { + lazyTerminate(STATE, this); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + throw e; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + + throw Exceptions.propagate(e); + } + + final int streamId; + try { + streamId = this.requesterResponderSupport.getNextStreamId(); + } catch (Throwable t) { + lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(Exceptions.unwrap(t), FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + throw Exceptions.propagate(t); + } + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onStart(streamId, FrameType.REQUEST_FNF, p.metadata()); + } + + try { + sendReleasingPayload( + streamId, + FrameType.REQUEST_FNF, + this.mtu, + this.payload, + this.connection, + this.allocator, + true); + } catch (Throwable e) { + lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, e); + } + + throw Exceptions.propagate(e); + } + + lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, null); + } + + return null; + } + + @Override + public Object scanUnsafe(Scannable.Attr key) { + return null; // no particular key to be represented, still useful in hooks + } + + @Override + @NonNull + public String stepName() { + return "source(FireAndForgetMono)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java new file mode 100644 index 000000000..e76fdf9ed --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java @@ -0,0 +1,183 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; + +final class FireAndForgetResponderSubscriber + implements CoreSubscriber, ResponderFrameHandler { + + static final Logger logger = LoggerFactory.getLogger(FireAndForgetResponderSubscriber.class); + + static final FireAndForgetResponderSubscriber INSTANCE = new FireAndForgetResponderSubscriber(); + + final int streamId; + final ByteBufAllocator allocator; + final PayloadDecoder payloadDecoder; + final RequesterResponderSupport requesterResponderSupport; + final RSocket handler; + final int maxInboundPayloadSize; + + @Nullable final RequestInterceptor requestInterceptor; + + CompositeByteBuf frames; + + private FireAndForgetResponderSubscriber() { + this.streamId = 0; + this.allocator = null; + this.payloadDecoder = null; + this.maxInboundPayloadSize = 0; + this.requesterResponderSupport = null; + this.handler = null; + this.requestInterceptor = null; + this.frames = null; + } + + FireAndForgetResponderSubscriber( + int streamId, RequesterResponderSupport requesterResponderSupport) { + this.streamId = streamId; + this.allocator = null; + this.payloadDecoder = null; + this.maxInboundPayloadSize = 0; + this.requesterResponderSupport = null; + this.handler = null; + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + this.frames = null; + } + + FireAndForgetResponderSubscriber( + int streamId, + ByteBuf firstFrame, + RequesterResponderSupport requesterResponderSupport, + RSocket handler) { + this.streamId = streamId; + this.allocator = requesterResponderSupport.getAllocator(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.handler = handler; + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + + this.frames = + ReassemblyUtils.addFollowingFrame( + allocator.compositeBuffer(), firstFrame, true, maxInboundPayloadSize); + } + + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(Void voidVal) {} + + @Override + public void onError(Throwable t) { + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(this.streamId, FrameType.REQUEST_FNF, t); + } + + logger.debug("Dropped Outbound error", t); + } + + @Override + public void onComplete() { + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(this.streamId, FrameType.REQUEST_FNF, null); + } + } + + @Override + public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLastPayload) { + final CompositeByteBuf frames = this.frames; + + try { + ReassemblyUtils.addFollowingFrame( + frames, followingFrame, hasFollows, this.maxInboundPayloadSize); + } catch (IllegalStateException t) { + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + this.frames = null; + frames.release(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_FNF, t); + } + + logger.debug("Reassembly has failed", t); + return; + } + + if (!hasFollows) { + this.requesterResponderSupport.remove(this.streamId, this); + this.frames = null; + + Payload payload; + try { + payload = this.payloadDecoder.apply(frames); + frames.release(); + } catch (Throwable t) { + ReferenceCountUtil.safeRelease(frames); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(this.streamId, FrameType.REQUEST_FNF, t); + } + + logger.debug("Reassembly has failed", t); + return; + } + + Mono source = this.handler.fireAndForget(payload); + source.subscribe(this); + } + } + + @Override + public final void handleCancel() { + final CompositeByteBuf frames = this.frames; + if (frames != null) { + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + this.frames = null; + frames.release(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_FNF); + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/FragmentationUtils.java b/rsocket-core/src/main/java/io/rsocket/core/FragmentationUtils.java new file mode 100644 index 000000000..03b6f9e09 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/FragmentationUtils.java @@ -0,0 +1,224 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import reactor.util.annotation.Nullable; + +class FragmentationUtils { + + static final int MIN_MTU_SIZE = 64; + + static final int FRAME_OFFSET = // 9 bytes in total + FrameLengthCodec.FRAME_LENGTH_SIZE // includes encoded frame length bytes size + + FrameHeaderCodec.size(); // includes encoded frame headers info bytes size + static final int FRAME_OFFSET_WITH_METADATA = // 12 bytes in total + FRAME_OFFSET + + FrameLengthCodec.FRAME_LENGTH_SIZE; // include encoded metadata length bytes size + + static final int FRAME_OFFSET_WITH_INITIAL_REQUEST_N = // 13 bytes in total + FRAME_OFFSET + Integer.BYTES; // includes extra space for initialRequestN bytes size + static final int FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N = // 16 bytes in total + FRAME_OFFSET_WITH_METADATA + + Integer.BYTES; // includes extra space for initialRequestN bytes size + + static boolean isFragmentable( + int mtu, ByteBuf data, @Nullable ByteBuf metadata, boolean hasInitialRequestN) { + if (mtu == 0) { + return false; + } + + if (metadata != null) { + int remaining = + mtu + - (hasInitialRequestN + ? FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N + : FRAME_OFFSET_WITH_METADATA); + + return (metadata.readableBytes() + data.readableBytes()) > remaining; + } else { + int remaining = + mtu - (hasInitialRequestN ? FRAME_OFFSET_WITH_INITIAL_REQUEST_N : FRAME_OFFSET); + + return data.readableBytes() > remaining; + } + } + + static ByteBuf encodeFollowsFragment( + ByteBufAllocator allocator, + int mtu, + int streamId, + boolean complete, + ByteBuf metadata, + ByteBuf data) { + // subtract the header bytes + frame length size + int remaining = mtu - FRAME_OFFSET; + + ByteBuf metadataFragment = null; + if (metadata.isReadable()) { + // subtract the metadata frame length + remaining -= FrameLengthCodec.FRAME_LENGTH_SIZE; + int r = Math.min(remaining, metadata.readableBytes()); + remaining -= r; + metadataFragment = metadata.readRetainedSlice(r); + } + + ByteBuf dataFragment = Unpooled.EMPTY_BUFFER; + try { + if (remaining > 0 && data.isReadable()) { + int r = Math.min(remaining, data.readableBytes()); + dataFragment = data.readRetainedSlice(r); + } + } catch (IllegalReferenceCountException | NullPointerException e) { + if (metadataFragment != null) { + metadataFragment.release(); + } + throw e; + } + + boolean follows = data.isReadable() || metadata.isReadable(); + return PayloadFrameCodec.encode( + allocator, streamId, follows, (!follows && complete), true, metadataFragment, dataFragment); + } + + static ByteBuf encodeFirstFragment( + ByteBufAllocator allocator, + int mtu, + FrameType frameType, + int streamId, + boolean hasMetadata, + ByteBuf metadata, + ByteBuf data) { + // subtract the header bytes + frame length size + int remaining = mtu - FRAME_OFFSET; + + ByteBuf metadataFragment = hasMetadata ? Unpooled.EMPTY_BUFFER : null; + if (hasMetadata) { + // subtract the metadata frame length + remaining -= FrameLengthCodec.FRAME_LENGTH_SIZE; + if (metadata.isReadable()) { + int r = Math.min(remaining, metadata.readableBytes()); + remaining -= r; + metadataFragment = metadata.readRetainedSlice(r); + } + } + + ByteBuf dataFragment = Unpooled.EMPTY_BUFFER; + try { + if (remaining > 0 && data.isReadable()) { + int r = Math.min(remaining, data.readableBytes()); + dataFragment = data.readRetainedSlice(r); + } + } catch (IllegalReferenceCountException | NullPointerException e) { + if (metadataFragment != null) { + metadataFragment.release(); + } + throw e; + } + + switch (frameType) { + case REQUEST_FNF: + return RequestFireAndForgetFrameCodec.encode( + allocator, streamId, true, metadataFragment, dataFragment); + case REQUEST_RESPONSE: + return RequestResponseFrameCodec.encode( + allocator, streamId, true, metadataFragment, dataFragment); + // Payload and synthetic types from the responder side + case PAYLOAD: + return PayloadFrameCodec.encode( + allocator, streamId, true, false, false, metadataFragment, dataFragment); + case NEXT: + // see https://github.com/rsocket/rsocket/blob/master/Protocol.md#handling-the-unexpected + // point 7 + case NEXT_COMPLETE: + return PayloadFrameCodec.encode( + allocator, streamId, true, false, true, metadataFragment, dataFragment); + default: + throw new IllegalStateException("unsupported fragment type: " + frameType); + } + } + + static ByteBuf encodeFirstFragment( + ByteBufAllocator allocator, + int mtu, + long initialRequestN, + FrameType frameType, + int streamId, + boolean hasMetadata, + ByteBuf metadata, + ByteBuf data) { + // subtract the header bytes + frame length bytes + initial requestN bytes + int remaining = mtu - FRAME_OFFSET_WITH_INITIAL_REQUEST_N; + + ByteBuf metadataFragment = hasMetadata ? Unpooled.EMPTY_BUFFER : null; + if (hasMetadata) { + // subtract the metadata frame length + remaining -= FrameLengthCodec.FRAME_LENGTH_SIZE; + if (metadata.isReadable()) { + int r = Math.min(remaining, metadata.readableBytes()); + remaining -= r; + metadataFragment = metadata.readRetainedSlice(r); + } + } + + ByteBuf dataFragment = Unpooled.EMPTY_BUFFER; + try { + if (remaining > 0 && data.isReadable()) { + int r = Math.min(remaining, data.readableBytes()); + dataFragment = data.readRetainedSlice(r); + } + } catch (IllegalReferenceCountException | NullPointerException e) { + if (metadataFragment != null) { + metadataFragment.release(); + } + throw e; + } + + switch (frameType) { + // Requester Side + case REQUEST_STREAM: + return RequestStreamFrameCodec.encode( + allocator, streamId, true, initialRequestN, metadataFragment, dataFragment); + case REQUEST_CHANNEL: + return RequestChannelFrameCodec.encode( + allocator, streamId, true, false, initialRequestN, metadataFragment, dataFragment); + default: + throw new IllegalStateException("unsupported fragment type: " + frameType); + } + } + + static int assertMtu(int mtu) { + if (mtu > 0 && mtu < MIN_MTU_SIZE || mtu < 0) { + String msg = + String.format( + "The smallest allowed mtu size is %d bytes, provided: %d", MIN_MTU_SIZE, mtu); + throw new IllegalArgumentException(msg); + } else { + return mtu; + } + } +} diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientServerTest.java b/rsocket-core/src/main/java/io/rsocket/core/FrameHandler.java similarity index 57% rename from rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientServerTest.java rename to rsocket-core/src/main/java/io/rsocket/core/FrameHandler.java index 9ddfccf1f..6d1ee1b09 100644 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientServerTest.java +++ b/rsocket-core/src/main/java/io/rsocket/core/FrameHandler.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.transport.local; +package io.rsocket.core; -import io.rsocket.test.BaseClientServerTest; +import io.netty.buffer.ByteBuf; -public class LocalClientServerTest extends BaseClientServerTest { +interface FrameHandler { - @Override - protected LocalClientSetupRule createClientServer() { - return new LocalClientSetupRule(); - } + void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload); + + void handleError(Throwable t); + + void handleComplete(); + + void handleCancel(); + + void handleRequestN(long n); } diff --git a/rsocket-core/src/main/java/io/rsocket/core/LeasePermitHandler.java b/rsocket-core/src/main/java/io/rsocket/core/LeasePermitHandler.java new file mode 100644 index 000000000..03ab7c257 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/LeasePermitHandler.java @@ -0,0 +1,20 @@ +package io.rsocket.core; + +/** Handler which enables async lease permits issuing */ +interface LeasePermitHandler { + + /** + * Called by {@link RequesterLeaseTracker} when there is an available lease + * + * @return {@code true} to indicate that lease permit was consumed successfully + */ + boolean handlePermit(); + + /** + * Called by {@link RequesterLeaseTracker} when there are no lease permit available at the moment + * and the list of awaiting {@link LeasePermitHandler} reached the configured limit + * + * @param t associated lease permit rejection exception + */ + void handlePermitError(Throwable t); +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/LeaseSpec.java b/rsocket-core/src/main/java/io/rsocket/core/LeaseSpec.java new file mode 100644 index 000000000..ad4b36e3a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/LeaseSpec.java @@ -0,0 +1,44 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.rsocket.lease.LeaseSender; +import reactor.core.publisher.Flux; + +public final class LeaseSpec { + + LeaseSender sender = Flux::never; + int maxPendingRequests = 256; + + LeaseSpec() {} + + public LeaseSpec sender(LeaseSender sender) { + this.sender = sender; + return this; + } + + /** + * Setup the maximum queued requests waiting for lease to be available. The default value is 256 + * + * @param maxPendingRequests if set to 0 the requester will terminate the request immediately if + * no leases is available + */ + public LeaseSpec maxPendingRequests(int maxPendingRequests) { + this.maxPendingRequests = maxPendingRequests; + return this; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/LoggingDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/core/LoggingDuplexConnection.java new file mode 100644 index 000000000..7b5d8f6c2 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/LoggingDuplexConnection.java @@ -0,0 +1,72 @@ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.FrameUtil; +import java.net.SocketAddress; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +class LoggingDuplexConnection implements DuplexConnection { + + private static final Logger LOGGER = LoggerFactory.getLogger("io.rsocket.FrameLogger"); + + final DuplexConnection source; + + LoggingDuplexConnection(DuplexConnection source) { + this.source = source; + } + + @Override + public void dispose() { + source.dispose(); + } + + @Override + public Mono onClose() { + return source.onClose(); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + LOGGER.debug("sending -> " + FrameUtil.toString(frame)); + + source.sendFrame(streamId, frame); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + LOGGER.debug("sending -> " + e.getClass().getSimpleName() + ": " + e.getMessage()); + + source.sendErrorAndClose(e); + } + + @Override + public Flux receive() { + return source + .receive() + .doOnNext(frame -> LOGGER.debug("receiving -> " + FrameUtil.toString(frame))); + } + + @Override + public ByteBufAllocator alloc() { + return source.alloc(); + } + + @Override + public SocketAddress remoteAddress() { + return source.remoteAddress(); + } + + static DuplexConnection wrapIfEnabled(DuplexConnection source) { + if (LOGGER.isDebugEnabled()) { + return new LoggingDuplexConnection(source); + } + + return source; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/MetadataPushRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/MetadataPushRequesterMono.java new file mode 100644 index 000000000..e2512e995 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/MetadataPushRequesterMono.java @@ -0,0 +1,190 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValidMetadata; +import static io.rsocket.core.StateUtils.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.frame.MetadataPushFrameCodec; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import reactor.core.CoreSubscriber; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; + +final class MetadataPushRequesterMono extends Mono implements Scannable { + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(MetadataPushRequesterMono.class, "state"); + + final ByteBufAllocator allocator; + final Payload payload; + final int maxFrameLength; + final DuplexConnection connection; + + MetadataPushRequesterMono(Payload payload, RequesterResponderSupport requesterResponderSupport) { + this.allocator = requesterResponderSupport.getAllocator(); + this.payload = payload; + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.connection = requesterResponderSupport.getDuplexConnection(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + Operators.error( + actual, new IllegalStateException("MetadataPushMono allows only a single Subscriber")); + return; + } + + final Payload p = this.payload; + final ByteBuf metadata; + try { + final boolean hasMetadata = p.hasMetadata(); + metadata = p.metadata(); + if (!hasMetadata) { + lazyTerminate(STATE, this); + p.release(); + Operators.error( + actual, + new IllegalArgumentException("Metadata push should have metadata field present")); + return; + } + if (!isValidMetadata(this.maxFrameLength, metadata)) { + lazyTerminate(STATE, this); + p.release(); + Operators.error( + actual, + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + return; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + Operators.error(actual, e); + return; + } + + final ByteBuf metadataRetainedSlice; + try { + metadataRetainedSlice = metadata.retainedSlice(); + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + Operators.error(actual, e); + return; + } + + try { + p.release(); + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + metadataRetainedSlice.release(); + Operators.error(actual, e); + return; + } + + final ByteBuf requestFrame = + MetadataPushFrameCodec.encode(this.allocator, metadataRetainedSlice); + this.connection.sendFrame(0, requestFrame); + + Operators.complete(actual); + } + + @Override + @Nullable + public Void block(Duration m) { + return block(); + } + + /** + * This method is deliberately non-blocking regardless it is named as `.block`. The main intent to + * keep this method along with the {@link #subscribe()} is to eliminate redundancy which comes + * with a default block method implementation. + */ + @Override + @Nullable + public Void block() { + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + throw new IllegalStateException("MetadataPushMono allows only a single Subscriber"); + } + + final Payload p = this.payload; + final ByteBuf metadata; + try { + final boolean hasMetadata = p.hasMetadata(); + metadata = p.metadata(); + if (!hasMetadata) { + lazyTerminate(STATE, this); + p.release(); + throw new IllegalArgumentException("Metadata push should have metadata field present"); + } + if (!isValidMetadata(this.maxFrameLength, metadata)) { + lazyTerminate(STATE, this); + p.release(); + throw new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + throw e; + } + + final ByteBuf metadataRetainedSlice; + try { + metadataRetainedSlice = metadata.retainedSlice(); + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + throw e; + } + + try { + p.release(); + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + metadataRetainedSlice.release(); + throw e; + } + + final ByteBuf requestFrame = + MetadataPushFrameCodec.encode(this.allocator, metadataRetainedSlice); + this.connection.sendFrame(0, requestFrame); + + return null; + } + + @Override + public Object scanUnsafe(Attr key) { + return null; // no particular key to be represented, still useful in hooks + } + + @Override + @NonNull + public String stepName() { + return "source(MetadataPushMono)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/MetadataPushResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/MetadataPushResponderSubscriber.java new file mode 100644 index 000000000..4c69934e8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/MetadataPushResponderSubscriber.java @@ -0,0 +1,45 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; + +final class MetadataPushResponderSubscriber implements CoreSubscriber { + static final Logger logger = LoggerFactory.getLogger(MetadataPushResponderSubscriber.class); + + static final MetadataPushResponderSubscriber INSTANCE = new MetadataPushResponderSubscriber(); + + private MetadataPushResponderSubscriber() {} + + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(Void voidVal) {} + + @Override + public void onError(Throwable t) { + logger.debug("Dropped error", t); + } + + @Override + public void onComplete() {} +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java b/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java new file mode 100644 index 000000000..6ece319c9 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java @@ -0,0 +1,76 @@ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_INITIAL_REQUEST_N; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Payload; + +final class PayloadValidationUtils { + static final String INVALID_PAYLOAD_ERROR_MESSAGE = + "The payload is too big to be send as a single frame with a max frame length %s. Consider enabling fragmentation."; + + static boolean isValid(int mtu, int maxFrameLength, Payload payload, boolean hasInitialRequestN) { + + if (mtu > 0) { + return true; + } + + final boolean hasMetadata = payload.hasMetadata(); + final ByteBuf data = payload.data(); + + int unitSize; + if (hasMetadata) { + final ByteBuf metadata = payload.metadata(); + unitSize = + (hasInitialRequestN + ? FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N + : FRAME_OFFSET_WITH_METADATA) + + metadata.readableBytes() + + // metadata payload bytes + data.readableBytes(); // data payload bytes + } else { + unitSize = + (hasInitialRequestN ? FRAME_OFFSET_WITH_INITIAL_REQUEST_N : FRAME_OFFSET) + + data.readableBytes(); // data payload bytes + } + + return unitSize <= maxFrameLength; + } + + static boolean isValidMetadata(int maxFrameLength, ByteBuf metadata) { + return FRAME_OFFSET + metadata.readableBytes() <= maxFrameLength; + } + + static void assertValidateSetup(int maxFrameLength, int maxInboundPayloadSize, int mtu) { + + if (maxFrameLength > FRAME_LENGTH_MASK) { + throw new IllegalArgumentException( + "Configured maxFrameLength[" + + maxFrameLength + + "] exceeds maxFrameLength limit " + + FRAME_LENGTH_MASK); + } + + if (maxFrameLength > maxInboundPayloadSize) { + throw new IllegalArgumentException( + "Configured maxFrameLength[" + + maxFrameLength + + "] exceeds maxPayloadSize[" + + maxInboundPayloadSize + + "]"); + } + + if (mtu != 0 && mtu > maxFrameLength) { + throw new IllegalArgumentException( + "Configured maximumTransmissionUnit[" + + mtu + + "] exceeds configured maxFrameLength[" + + maxFrameLength + + "]"); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketClient.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketClient.java new file mode 100644 index 000000000..32e3c229d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketClient.java @@ -0,0 +1,153 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import sun.reflect.generics.reflectiveObjects.NotImplementedException; + +/** + * Contract for performing RSocket requests. + * + *

{@link RSocketClient} differs from {@link RSocket} in a number of ways: + * + *

    + *
  • {@code RSocket} represents a "live" connection that is transient and needs to be obtained + * typically from a {@code Mono} source via {@code flatMap} or block. By contrast, + * {@code RSocketClient} is a higher level layer that contains such a {@link #source() source} + * of connections and transparently obtains and re-obtains a shared connection as needed when + * requests are made concurrently. That means an {@code RSocketClient} can simply be created + * once, even before a connection is established, and shared as a singleton across multiple + * places as you would with any other client. + *
  • For request input {@code RSocket} accepts an instance of {@code Payload} and does not allow + * more than one subscription per request because there is no way to safely re-use that input. + * By contrast {@code RSocketClient} accepts {@code Publisher} and allow + * re-subscribing which repeats the request. + *
  • {@code RSocket} can be used for sending and it can also be implemented for receiving. By + * contrast {@code RSocketClient} is used only for sending, typically from the client side + * which allows obtaining and re-obtaining connections from a source as needed. However it can + * also be used from the server side by {@link #from(RSocket) wrapping} the "live" {@code + * RSocket} for a given connection. + *
+ * + *

The example below shows how to create an {@code RSocketClient}: + * + *

{@code
+ * Mono source =
+ *         RSocketConnector.create()
+ *                 .metadataMimeType("message/x.rsocket.composite-metadata.v0")
+ *                 .dataMimeType("application/cbor")
+ *                 .connect(TcpClientTransport.create("localhost", 7000));
+ *
+ * RSocketClient client = RSocketClient.from(source);
+ * }
+ * + *

The below configures retry logic to use when a shared {@code RSocket} connection is obtained: + * + *

{@code
+ * Mono source =
+ *         RSocketConnector.create()
+ *                 .metadataMimeType("message/x.rsocket.composite-metadata.v0")
+ *                 .dataMimeType("application/cbor")
+ *                 .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+ *                 .connect(TcpClientTransport.create("localhost", 7000));
+ *
+ * RSocketClient client = RSocketClient.from(source);
+ * }
+ * + * @since 1.1 + * @see io.rsocket.loadbalance.LoadbalanceRSocketClient + */ +public interface RSocketClient extends Closeable { + + /** + * Connect to the remote rsocket endpoint, if not yet connected. This method is a shortcut for + * {@code RSocketClient#source().subscribe()}. + * + * @return {@code true} if an attempt to connect was triggered or if already connected, or {@code + * false} if the client is terminated. + */ + default boolean connect() { + throw new NotImplementedException(); + } + + default Mono onClose() { + return Mono.error(new NotImplementedException()); + } + + /** Return the underlying source used to obtain a shared {@link RSocket} connection. */ + Mono source(); + + /** + * Perform a Fire-and-Forget interaction via {@link RSocket#fireAndForget(Payload)}. Allows + * multiple subscriptions and performs a request per subscriber. + */ + Mono fireAndForget(Mono payloadMono); + + /** + * Perform a Request-Response interaction via {@link RSocket#requestResponse(Payload)}. Allows + * multiple subscriptions and performs a request per subscriber. + */ + Mono requestResponse(Mono payloadMono); + + /** + * Perform a Request-Stream interaction via {@link RSocket#requestStream(Payload)}. Allows + * multiple subscriptions and performs a request per subscriber. + */ + Flux requestStream(Mono payloadMono); + + /** + * Perform a Request-Channel interaction via {@link RSocket#requestChannel(Publisher)}. Allows + * multiple subscriptions and performs a request per subscriber. + */ + Flux requestChannel(Publisher payloads); + + /** + * Perform a Metadata Push via {@link RSocket#metadataPush(Payload)}. Allows multiple + * subscriptions and performs a request per subscriber. + */ + Mono metadataPush(Mono payloadMono); + + /** + * Create an {@link RSocketClient} that obtains shared connections as needed, when requests are + * made, from the given {@code Mono} source. + * + * @param source the source for connections, typically prepared via {@link RSocketConnector}. + * @return the created client instance + */ + static RSocketClient from(Mono source) { + return new DefaultRSocketClient(source); + } + + /** + * Adapt the given {@link RSocket} to use as {@link RSocketClient}. This is useful to wrap the + * sending {@code RSocket} in a server. + * + *

Note: unlike an {@code RSocketClient} created via {@link + * RSocketClient#from(Mono)}, the instance returned from this factory method can only perform + * requests for as long as the given {@code RSocket} remains "live". + * + * @param rsocket the {@code RSocket} to perform requests with + * @return the created client instance + */ + static RSocketClient from(RSocket rsocket) { + return new RSocketClientAdapter(rsocket); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketClientAdapter.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketClientAdapter.java new file mode 100644 index 000000000..ae8b7da97 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketClientAdapter.java @@ -0,0 +1,88 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Simple adapter from {@link RSocket} to {@link RSocketClient}. This is useful in code that needs + * to deal with both in the same way. When connecting to a server, typically {@link RSocketClient} + * is expected to be used, but in a responder (client or server), it is necessary to interact with + * {@link RSocket} to make requests to the remote end. + * + * @since 1.1 + */ +class RSocketClientAdapter implements RSocketClient { + + private final RSocket rsocket; + + public RSocketClientAdapter(RSocket rsocket) { + this.rsocket = rsocket; + } + + public RSocket rsocket() { + return rsocket; + } + + @Override + public boolean connect() { + throw new UnsupportedOperationException("Connect does not apply to a server side RSocket"); + } + + @Override + public Mono source() { + return Mono.just(rsocket); + } + + @Override + public Mono onClose() { + return rsocket.onClose(); + } + + @Override + public Mono fireAndForget(Mono payloadMono) { + return payloadMono.flatMap(rsocket::fireAndForget); + } + + @Override + public Mono requestResponse(Mono payloadMono) { + return payloadMono.flatMap(rsocket::requestResponse); + } + + @Override + public Flux requestStream(Mono payloadMono) { + return payloadMono.flatMapMany(rsocket::requestStream); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return rsocket.requestChannel(payloads); + } + + @Override + public Mono metadataPush(Mono payloadMono) { + return payloadMono.flatMap(rsocket::metadataPush); + } + + @Override + public void dispose() { + rsocket.dispose(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java new file mode 100644 index 000000000..de494c4e3 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java @@ -0,0 +1,746 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.assertMtu; +import static io.rsocket.core.PayloadValidationUtils.assertValidateSetup; +import static io.rsocket.core.ReassemblyUtils.assertInboundPayloadSize; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.keepalive.KeepAliveHandler; +import io.rsocket.lease.TrackingLeaseSender; +import io.rsocket.plugins.DuplexConnectionInterceptor; +import io.rsocket.plugins.InitializingInterceptorRegistry; +import io.rsocket.plugins.InterceptorRegistry; +import io.rsocket.plugins.RequestInterceptor; +import io.rsocket.resume.ClientRSocketSession; +import io.rsocket.resume.ResumableDuplexConnection; +import io.rsocket.resume.ResumableFramesStore; +import io.rsocket.transport.ClientTransport; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.Objects; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Supplier; +import reactor.core.Disposable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; +import reactor.util.function.Tuples; +import reactor.util.retry.Retry; + +/** + * The main class to use to establish a connection to an RSocket server. + * + *

For using TCP using default settings: + * + *

{@code
+ * import io.rsocket.transport.netty.client.TcpClientTransport;
+ *
+ * Mono source =
+ *         RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000));
+ * RSocketClient client = RSocketClient.from(source);
+ * }
+ * + *

To customize connection settings before connecting: + * + *

{@code
+ * Mono source =
+ *         RSocketConnector.create()
+ *                 .metadataMimeType("message/x.rsocket.composite-metadata.v0")
+ *                 .dataMimeType("application/cbor")
+ *                 .connect(TcpClientTransport.create("localhost", 7000));
+ * RSocketClient client = RSocketClient.from(source);
+ * }
+ */ +public class RSocketConnector { + private static final String CLIENT_TAG = "client"; + + private static final BiConsumer INVALIDATE_FUNCTION = + (r, i) -> r.onClose().subscribe(null, __ -> i.invalidate(), i::invalidate); + + private Mono setupPayloadMono = Mono.empty(); + private String metadataMimeType = "application/binary"; + private String dataMimeType = "application/binary"; + private Duration keepAliveInterval = Duration.ofSeconds(20); + private Duration keepAliveMaxLifeTime = Duration.ofSeconds(90); + + @Nullable private SocketAcceptor acceptor; + private InitializingInterceptorRegistry interceptors = new InitializingInterceptorRegistry(); + + private Retry retrySpec; + private Resume resume; + + @Nullable private Consumer leaseConfigurer; + + private int mtu = 0; + private int maxInboundPayloadSize = Integer.MAX_VALUE; + private PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; + + private RSocketConnector() {} + + /** + * Static factory method to create an {@code RSocketConnector} instance and customize default + * settings before connecting. To connect only, use {@link #connectWith(ClientTransport)}. + */ + public static RSocketConnector create() { + return new RSocketConnector(); + } + + /** + * Static factory method to connect with default settings, effectively a shortcut for: + * + *
+   * RSocketConnector.create().connect(transport);
+   * 
+ * + * @param transport the transport of choice to connect with + * @return a {@code Mono} with the connected RSocket + */ + public static Mono connectWith(ClientTransport transport) { + return RSocketConnector.create().connect(() -> transport); + } + + /** + * Provide a {@code Mono} from which to obtain the {@code Payload} for the initial SETUP frame. + * Data and metadata should be formatted according to the MIME types specified via {@link + * #dataMimeType(String)} and {@link #metadataMimeType(String)}. + * + * @param setupPayloadMono the payload with data and/or metadata for the {@code SETUP} frame. + * @return the same instance for method chaining + * @since 1.0.2 + * @see
SETUP + * Frame + */ + public RSocketConnector setupPayload(Mono setupPayloadMono) { + this.setupPayloadMono = setupPayloadMono; + return this; + } + + /** + * Variant of {@link #setupPayload(Mono)} that accepts a {@code Payload} instance. + * + *

Note: if the given payload is {@link io.rsocket.util.ByteBufPayload}, it is copied to a + * {@link DefaultPayload} and released immediately. This ensures it can re-used to obtain a + * connection more than once. + * + * @param payload the payload with data and/or metadata for the {@code SETUP} frame. + * @return the same instance for method chaining + * @see SETUP + * Frame + */ + public RSocketConnector setupPayload(Payload payload) { + if (payload instanceof DefaultPayload) { + this.setupPayloadMono = Mono.just(payload); + } else { + this.setupPayloadMono = Mono.just(DefaultPayload.create(Objects.requireNonNull(payload))); + payload.release(); + } + return this; + } + + /** + * Set the MIME type to use for formatting payload data on the established connection. This is set + * in the initial {@code SETUP} frame sent to the server. + * + *

By default this is set to {@code "application/binary"}. + * + * @param dataMimeType the MIME type to be used for payload data + * @return the same instance for method chaining + * @see SETUP + * Frame + */ + public RSocketConnector dataMimeType(String dataMimeType) { + this.dataMimeType = Objects.requireNonNull(dataMimeType); + return this; + } + + /** + * Set the MIME type to use for formatting payload metadata on the established connection. This is + * set in the initial {@code SETUP} frame sent to the server. + * + *

For metadata encoding, consider using one of the following encoders: + * + *

    + *
  • {@link io.rsocket.metadata.CompositeMetadataCodec Composite Metadata} + *
  • {@link io.rsocket.metadata.TaggingMetadataCodec Routing} + *
  • {@link io.rsocket.metadata.AuthMetadataCodec Authentication} + *
+ * + *

For more on the above metadata formats, see the corresponding protocol extensions + * + *

By default this is set to {@code "application/binary"}. + * + * @param metadataMimeType the MIME type to be used for payload metadata + * @return the same instance for method chaining + * @see SETUP + * Frame + */ + public RSocketConnector metadataMimeType(String metadataMimeType) { + this.metadataMimeType = Objects.requireNonNull(metadataMimeType); + return this; + } + + /** + * Set the "Time Between {@code KEEPALIVE} Frames" which is how frequently {@code KEEPALIVE} + * frames should be emitted, and the "Max Lifetime" which is how long to allow between {@code + * KEEPALIVE} frames from the remote end before concluding that connectivity is lost. Both + * settings are specified in the initial {@code SETUP} frame sent to the server. The spec mentions + * the following: + * + *

    + *
  • For server-to-server connections, a reasonable time interval between client {@code + * KEEPALIVE} frames is 500ms. + *
  • For mobile-to-server connections, the time interval between client {@code KEEPALIVE} + * frames is often {@code >} 30,000ms. + *
+ * + *

By default these are set to 20 seconds and 90 seconds respectively. + * + * @param interval how frequently to emit KEEPALIVE frames + * @param maxLifeTime how long to allow between {@code KEEPALIVE} frames from the remote end + * before assuming that connectivity is lost; the value should be generous and allow for + * multiple missed {@code KEEPALIVE} frames. + * @return the same instance for method chaining + * @see SETUP + * Frame + */ + public RSocketConnector keepAlive(Duration interval, Duration maxLifeTime) { + if (!interval.negated().isNegative()) { + throw new IllegalArgumentException("`interval` for keepAlive must be > 0"); + } + if (!maxLifeTime.negated().isNegative()) { + throw new IllegalArgumentException("`maxLifeTime` for keepAlive must be > 0"); + } + this.keepAliveInterval = interval; + this.keepAliveMaxLifeTime = maxLifeTime; + return this; + } + + /** + * Configure interception at one of the following levels: + * + *

    + *
  • Transport level + *
  • At the level of accepting new connections + *
  • Performing requests + *
  • Responding to requests + *
+ * + * @param configurer a configurer to customize interception with. + * @return the same instance for method chaining + * @see io.rsocket.plugins.LimitRateInterceptor + */ + public RSocketConnector interceptors(Consumer configurer) { + configurer.accept(this.interceptors); + return this; + } + + /** + * Configure a client-side {@link SocketAcceptor} for responding to requests from the server. + * + *

A full-form example with access to the {@code SETUP} frame and the "sending" RSocket (the + * same as the one returned from {@link #connect(ClientTransport)}): + * + *

{@code
+   * Mono rsocketMono =
+   *     RSocketConnector.create()
+   *             .acceptor((setup, sendingRSocket) -> Mono.just(new RSocket() {...}))
+   *             .connect(transport);
+   * }
+ * + *

A shortcut example with just the handling RSocket: + * + *

{@code
+   * Mono rsocketMono =
+   *     RSocketConnector.create()
+   *             .acceptor(SocketAcceptor.with(new RSocket() {...})))
+   *             .connect(transport);
+   * }
+ * + *

A shortcut example handling only request-response: + * + *

{@code
+   * Mono rsocketMono =
+   *     RSocketConnector.create()
+   *             .acceptor(SocketAcceptor.forRequestResponse(payload -> ...))
+   *             .connect(transport);
+   * }
+ * + *

By default, {@code new RSocket(){}} is used which rejects all requests from the server with + * {@link UnsupportedOperationException}. + * + * @param acceptor the acceptor to use for responding to server requests + * @return the same instance for method chaining + */ + public RSocketConnector acceptor(SocketAcceptor acceptor) { + this.acceptor = acceptor; + return this; + } + + /** + * When this is enabled, the connect methods of this class return a special {@code Mono} + * that maintains a single, shared {@code RSocket} for all subscribers: + * + *

{@code
+   * Mono rsocketMono =
+   *   RSocketConnector.create()
+   *           .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+   *           .connect(transport);
+   *
+   *  RSocket r1 = rsocketMono.block();
+   *  RSocket r2 = rsocketMono.block();
+   *
+   *  assert r1 == r2;
+   * }
+ * + *

The {@code RSocket} remains cached until the connection is lost and after that, new attempts + * to subscribe or re-subscribe trigger a reconnect and result in a new shared {@code RSocket}: + * + *

{@code
+   * Mono rsocketMono =
+   *   RSocketConnector.create()
+   *           .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+   *           .connect(transport);
+   *
+   *  RSocket r1 = rsocketMono.block();
+   *  RSocket r2 = rsocketMono.block();
+   *
+   *  r1.dispose();
+   *
+   *  RSocket r3 = rsocketMono.block();
+   *  RSocket r4 = rsocketMono.block();
+   *
+   *  assert r1 == r2;
+   *  assert r3 == r4;
+   *  assert r1 != r3;
+   *
+   * }
+ * + *

Downstream subscribers for individual requests still need their own retry logic to determine + * if or when failed requests should be retried which in turn triggers the shared reconnect: + * + *

{@code
+   * Mono rocketMono =
+   *   RSocketConnector.create()
+   *           .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+   *           .connect(transport);
+   *
+   *  rsocketMono.flatMap(rsocket -> rsocket.requestResponse(...))
+   *           .retryWhen(Retry.fixedDelay(1, Duration.ofSeconds(5)))
+   *           .subscribe()
+   * }
+ * + *

Note: this feature is mutually exclusive with {@link #resume(Resume)}. If + * both are enabled, "resume" takes precedence. Consider using "reconnect" when the server does + * not have "resume" enabled or supported, or when you don't need to incur the overhead of saving + * in-flight frames to be potentially replayed after a reconnect. + * + *

By default this is not enabled in which case a new connection is obtained per subscriber. + * + * @param retry a retry spec that declares the rules for reconnecting + * @return the same instance for method chaining + */ + public RSocketConnector reconnect(Retry retry) { + this.retrySpec = Objects.requireNonNull(retry); + return this; + } + + /** + * Enables the Resume capability of the RSocket protocol where if the client gets disconnected, + * the connection is re-acquired and any interrupted streams are resumed automatically. For this + * to work the server must also support and have the Resume capability enabled. + * + *

See {@link Resume} for settings to customize the Resume capability. + * + *

Note: this feature is mutually exclusive with {@link #reconnect(Retry)}. If + * both are enabled, "resume" takes precedence. Consider using "reconnect" when the server does + * not have "resume" enabled or supported, or when you don't need to incur the overhead of saving + * in-flight frames to be potentially replayed after a reconnect. + * + *

By default this is not enabled. + * + * @param resume configuration for the Resume capability + * @return the same instance for method chaining + * @see Resuming + * Operation + */ + public RSocketConnector resume(Resume resume) { + this.resume = resume; + return this; + } + + /** + * Enables the Lease feature of the RSocket protocol where the number of requests that can be + * performed from either side are rationed via {@code LEASE} frames from the responder side. + * + *

Example usage: + * + *

{@code
+   * Mono rocketMono =
+   *         RSocketConnector.create()
+   *                         .lease()
+   *                         .connect(transport);
+   * }
+ * + *

By default this is not enabled. + * + * @return the same instance for method chaining + * @see Lease + * Semantics + */ + public RSocketConnector lease() { + return lease((config -> {})); + } + + /** + * Enables the Lease feature of the RSocket protocol where the number of requests that can be + * performed from either side are rationed via {@code LEASE} frames from the responder side. + * + *

Example usage: + * + *

{@code
+   * Mono rocketMono =
+   *         RSocketConnector.create()
+   *                         .lease(spec -> spec.maxPendingRequests(128))
+   *                         .connect(transport);
+   * }
+ * + *

By default this is not enabled. + * + * @param leaseConfigurer consumer which accepts {@link LeaseSpec} and use it for configuring + * @return the same instance for method chaining + * @see Lease + * Semantics + */ + public RSocketConnector lease(Consumer leaseConfigurer) { + this.leaseConfigurer = leaseConfigurer; + return this; + } + + /** + * When this is set, frames reassembler control maximum payload size which can be reassembled. + * + *

By default this is not set in which case maximum reassembled payloads size is not + * controlled. + * + * @param maxInboundPayloadSize the threshold size for reassembly, must no be less than 64 bytes. + * Please note, {@code maxInboundPayloadSize} must always be greater or equal to {@link + * io.rsocket.transport.Transport#maxFrameLength()}, otherwise inbound frame can exceed the + * {@code maxInboundPayloadSize} + * @return the same instance for method chaining + * @see Fragmentation + * and Reassembly + */ + public RSocketConnector maxInboundPayloadSize(int maxInboundPayloadSize) { + this.maxInboundPayloadSize = assertInboundPayloadSize(maxInboundPayloadSize); + return this; + } + + /** + * When this is set, frames larger than the given maximum transmission unit (mtu) size value are + * broken down into fragments to fit that size. + * + *

By default this is not set in which case payloads are sent whole up to the maximum frame + * size of 16,777,215 bytes. + * + * @param mtu the threshold size for fragmentation, must be no less than 64 + * @return the same instance for method chaining + * @see Fragmentation + * and Reassembly + */ + public RSocketConnector fragment(int mtu) { + this.mtu = assertMtu(mtu); + return this; + } + + /** + * Configure the {@code PayloadDecoder} used to create {@link Payload}'s from incoming raw frame + * buffers. The following decoders are available: + * + *

    + *
  • {@link PayloadDecoder#DEFAULT} -- the data and metadata are independent copies of the + * underlying frame {@link ByteBuf} + *
  • {@link PayloadDecoder#ZERO_COPY} -- the data and metadata are retained slices of the + * underlying {@link ByteBuf}. That's more efficient but requires careful tracking and + * {@link Payload#release() release} of the payload when no longer needed. + *
+ * + *

By default this is set to {@link PayloadDecoder#DEFAULT} in which case data and metadata are + * copied and do not need to be tracked and released. + * + * @param decoder the decoder to use + * @return the same instance for method chaining + */ + public RSocketConnector payloadDecoder(PayloadDecoder decoder) { + Objects.requireNonNull(decoder); + this.payloadDecoder = decoder; + return this; + } + + /** + * Connect with the given transport and obtain a live {@link RSocket} to use for making requests. + * Each subscriber to the returned {@code Mono} receives a new connection, if neither {@link + * #reconnect(Retry) reconnect} nor {@link #resume(Resume)} are enabled. + * + *

The following transports are available through additional RSocket Java modules: + * + *

    + *
  • {@link io.rsocket.transport.netty.client.TcpClientTransport TcpClientTransport} via + * {@code rsocket-transport-netty}. + *
  • {@link io.rsocket.transport.netty.client.WebsocketClientTransport + * WebsocketClientTransport} via {@code rsocket-transport-netty}. + *
  • {@link io.rsocket.transport.local.LocalClientTransport LocalClientTransport} via {@code + * rsocket-transport-local} + *
+ * + * @param transport the transport of choice to connect with + * @return a {@code Mono} with the connected RSocket + */ + public Mono connect(ClientTransport transport) { + return connect(() -> transport); + } + + /** + * Variant of {@link #connect(ClientTransport)} with a {@link Supplier} for the {@code + * ClientTransport}. + * + *

// TODO: when to use? + * + * @param transportSupplier supplier for the transport to connect with + * @return a {@code Mono} with the connected RSocket + */ + public Mono connect(Supplier transportSupplier) { + return Mono.fromSupplier(transportSupplier) + .flatMap( + ct -> { + int maxFrameLength = ct.maxFrameLength(); + + Mono connectionMono = + Mono.fromCallable( + () -> { + assertValidateSetup(maxFrameLength, maxInboundPayloadSize, mtu); + return ct; + }) + .flatMap(transport -> transport.connect()) + .map( + sourceConnection -> + interceptors.initConnection( + DuplexConnectionInterceptor.Type.SOURCE, sourceConnection)) + .map(source -> LoggingDuplexConnection.wrapIfEnabled(source)); + + return connectionMono + .flatMap( + connection -> + setupPayloadMono + .defaultIfEmpty(EmptyPayload.INSTANCE) + .map(setupPayload -> Tuples.of(connection, setupPayload)) + .doOnError(ex -> connection.dispose()) + .doOnCancel(connection::dispose)) + .flatMap( + tuple2 -> { + DuplexConnection sourceConnection = tuple2.getT1(); + Payload setupPayload = tuple2.getT2(); + boolean leaseEnabled = leaseConfigurer != null; + boolean resumeEnabled = resume != null; + // TODO: add LeaseClientSetup + ClientSetup clientSetup = new DefaultClientSetup(); + ByteBuf resumeToken; + + if (resumeEnabled) { + resumeToken = resume.getTokenSupplier().get(); + } else { + resumeToken = Unpooled.EMPTY_BUFFER; + } + + ByteBuf setupFrame = + SetupFrameCodec.encode( + sourceConnection.alloc(), + leaseEnabled, + (int) keepAliveInterval.toMillis(), + (int) keepAliveMaxLifeTime.toMillis(), + resumeToken, + metadataMimeType, + dataMimeType, + setupPayload); + + sourceConnection.sendFrame(0, setupFrame.retainedSlice()); + + return clientSetup + .init(sourceConnection) + .flatMap( + tuple -> { + // should be used if lease setup sequence; + // See: + // https://github.com/rsocket/rsocket/blob/master/Protocol.md#sequences-with-lease + final ByteBuf serverResponse = tuple.getT1(); + final DuplexConnection clientServerConnection = tuple.getT2(); + final KeepAliveHandler keepAliveHandler; + final DuplexConnection wrappedConnection; + final InitializingInterceptorRegistry interceptors = + this.interceptors; + + if (resumeEnabled) { + final ResumableFramesStore resumableFramesStore = + resume.getStoreFactory(CLIENT_TAG).apply(resumeToken); + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + CLIENT_TAG, + resumeToken, + clientServerConnection, + resumableFramesStore); + final ResumableClientSetup resumableClientSetup = + new ResumableClientSetup(); + final ClientRSocketSession session = + new ClientRSocketSession( + resumeToken, + resumableDuplexConnection, + connectionMono, + resumableClientSetup::init, + resumableFramesStore, + resume.getSessionDuration(), + resume.getRetry(), + resume.isCleanupStoreOnKeepAlive()); + keepAliveHandler = + new KeepAliveHandler.ResumableKeepAliveHandler( + resumableDuplexConnection, session, session); + wrappedConnection = resumableDuplexConnection; + } else { + keepAliveHandler = + new KeepAliveHandler.DefaultKeepAliveHandler(); + wrappedConnection = clientServerConnection; + } + + ClientServerInputMultiplexer multiplexer = + new ClientServerInputMultiplexer( + wrappedConnection, interceptors, true); + + final LeaseSpec leases; + final RequesterLeaseTracker requesterLeaseTracker; + if (leaseEnabled) { + leases = new LeaseSpec(); + leaseConfigurer.accept(leases); + requesterLeaseTracker = + new RequesterLeaseTracker( + CLIENT_TAG, leases.maxPendingRequests); + } else { + leases = null; + requesterLeaseTracker = null; + } + + final Sinks.Empty requesterOnAllClosedSink = + Sinks.unsafe().empty(); + final Sinks.Empty responderOnAllClosedSink = + Sinks.unsafe().empty(); + + RSocket rSocketRequester = + new RSocketRequester( + multiplexer.asClientConnection(), + payloadDecoder, + StreamIdSupplier.clientSupplier(), + mtu, + maxFrameLength, + maxInboundPayloadSize, + (int) keepAliveInterval.toMillis(), + (int) keepAliveMaxLifeTime.toMillis(), + keepAliveHandler, + interceptors::initRequesterRequestInterceptor, + requesterLeaseTracker, + requesterOnAllClosedSink, + Mono.whenDelayError( + responderOnAllClosedSink.asMono(), + requesterOnAllClosedSink.asMono())); + + RSocket wrappedRSocketRequester = + interceptors.initRequester(rSocketRequester); + + SocketAcceptor acceptor = + this.acceptor != null + ? this.acceptor + : SocketAcceptor.with(new RSocket() {}); + + ConnectionSetupPayload setup = + new DefaultConnectionSetupPayload(setupFrame); + + return interceptors + .initSocketAcceptor(acceptor) + .accept(setup, wrappedRSocketRequester) + .map( + rSocketHandler -> { + RSocket wrappedRSocketHandler = + interceptors.initResponder(rSocketHandler); + + ResponderLeaseTracker responderLeaseTracker = + leaseEnabled + ? new ResponderLeaseTracker( + CLIENT_TAG, + wrappedConnection, + leases.sender) + : null; + + RSocket rSocketResponder = + new RSocketResponder( + multiplexer.asServerConnection(), + wrappedRSocketHandler, + payloadDecoder, + responderLeaseTracker, + mtu, + maxFrameLength, + maxInboundPayloadSize, + leaseEnabled + && leases.sender + instanceof TrackingLeaseSender + ? rSocket -> + interceptors + .initResponderRequestInterceptor( + rSocket, + (RequestInterceptor) + leases.sender) + : interceptors + ::initResponderRequestInterceptor, + responderOnAllClosedSink); + + return wrappedRSocketRequester; + }) + .doFinally(signalType -> setup.release()); + }); + }); + }) + .as( + source -> { + if (retrySpec != null) { + return new ReconnectMono<>( + source.retryWhen(retrySpec), Disposable::dispose, INVALIDATE_FUNCTION); + } else { + return source; + } + }); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java new file mode 100644 index 000000000..b8a9c00ff --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -0,0 +1,445 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.keepalive.KeepAliveSupport.ClientKeepAliveSupport; + +import io.netty.buffer.ByteBuf; +import io.netty.util.collection.IntObjectMap; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.exceptions.Exceptions; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.keepalive.KeepAliveFramesAcceptor; +import io.rsocket.keepalive.KeepAliveHandler; +import io.rsocket.keepalive.KeepAliveSupport; +import io.rsocket.plugins.RequestInterceptor; +import java.nio.channels.ClosedChannelException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Function; +import java.util.function.Supplier; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; + +/** + * Requester Side of a RSocket socket. Sends {@link ByteBuf}s to a {@link RSocketResponder} of peer + */ +class RSocketRequester extends RequesterResponderSupport implements RSocket { + private static final Logger LOGGER = LoggerFactory.getLogger(RSocketRequester.class); + + private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException(); + + static { + CLOSED_CHANNEL_EXCEPTION.setStackTrace(new StackTraceElement[0]); + } + + private volatile Throwable terminationError; + private static final AtomicReferenceFieldUpdater TERMINATION_ERROR = + AtomicReferenceFieldUpdater.newUpdater( + RSocketRequester.class, Throwable.class, "terminationError"); + + @Nullable private final RequesterLeaseTracker requesterLeaseTracker; + + private final Sinks.Empty onThisSideClosedSink; + private final Mono onAllClosed; + private final KeepAliveFramesAcceptor keepAliveFramesAcceptor; + + RSocketRequester( + DuplexConnection connection, + PayloadDecoder payloadDecoder, + StreamIdSupplier streamIdSupplier, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize, + int keepAliveTickPeriod, + int keepAliveAckTimeout, + @Nullable KeepAliveHandler keepAliveHandler, + Function requestInterceptorFunction, + @Nullable RequesterLeaseTracker requesterLeaseTracker, + Sinks.Empty onThisSideClosedSink, + Mono onAllClosed) { + super( + mtu, + maxFrameLength, + maxInboundPayloadSize, + payloadDecoder, + connection, + streamIdSupplier, + requestInterceptorFunction); + + this.requesterLeaseTracker = requesterLeaseTracker; + this.onThisSideClosedSink = onThisSideClosedSink; + this.onAllClosed = onAllClosed; + + // DO NOT Change the order here. The Send processor must be subscribed to before receiving + connection.onClose().subscribe(null, this::tryShutdown, this::tryShutdown); + + connection.receive().subscribe(this::handleIncomingFrames, e -> {}); + + if (keepAliveTickPeriod != 0 && keepAliveHandler != null) { + KeepAliveSupport keepAliveSupport = + new ClientKeepAliveSupport(this.getAllocator(), keepAliveTickPeriod, keepAliveAckTimeout); + this.keepAliveFramesAcceptor = + keepAliveHandler.start( + keepAliveSupport, + (keepAliveFrame) -> connection.sendFrame(0, keepAliveFrame), + this::tryTerminateOnKeepAlive); + } else { + keepAliveFramesAcceptor = null; + } + } + + @Override + public Mono fireAndForget(Payload payload) { + if (this.requesterLeaseTracker == null) { + return new FireAndForgetRequesterMono(payload, this); + } else { + return new SlowFireAndForgetRequesterMono(payload, this); + } + } + + @Override + public Mono requestResponse(Payload payload) { + return new RequestResponseRequesterMono(payload, this); + } + + @Override + public Flux requestStream(Payload payload) { + return new RequestStreamRequesterFlux(payload, this); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return new RequestChannelRequesterFlux(payloads, this); + } + + @Override + public Mono metadataPush(Payload payload) { + Throwable terminationError = this.terminationError; + if (terminationError != null) { + payload.release(); + return Mono.error(terminationError); + } + + return new MetadataPushRequesterMono(payload, this); + } + + @Override + public RequesterLeaseTracker getRequesterLeaseTracker() { + return this.requesterLeaseTracker; + } + + @Override + public int getNextStreamId() { + int nextStreamId = super.getNextStreamId(); + + Throwable terminationError = this.terminationError; + if (terminationError != null) { + throw reactor.core.Exceptions.propagate(terminationError); + } + + return nextStreamId; + } + + @Override + public int addAndGetNextStreamId(FrameHandler frameHandler) { + int nextStreamId = super.addAndGetNextStreamId(frameHandler); + + Throwable terminationError = this.terminationError; + if (terminationError != null) { + super.remove(nextStreamId, frameHandler); + throw reactor.core.Exceptions.propagate(terminationError); + } + + return nextStreamId; + } + + @Override + public double availability() { + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + if (requesterLeaseTracker != null) { + return Math.min(getDuplexConnection().availability(), requesterLeaseTracker.availability()); + } else { + return getDuplexConnection().availability(); + } + } + + @Override + public void dispose() { + if (terminationError != null) { + return; + } + + getDuplexConnection().sendErrorAndClose(new ConnectionErrorException("Disposed")); + } + + @Override + public boolean isDisposed() { + return terminationError != null; + } + + @Override + public Mono onClose() { + return onAllClosed; + } + + private void handleIncomingFrames(ByteBuf frame) { + try { + int streamId = FrameHeaderCodec.streamId(frame); + FrameType type = FrameHeaderCodec.frameType(frame); + if (streamId == 0) { + handleStreamZero(type, frame); + } else { + handleFrame(streamId, type, frame); + } + } catch (Throwable t) { + LOGGER.error("Unexpected error during frame handling", t); + final ConnectionErrorException error = + new ConnectionErrorException("Unexpected error during frame handling", t); + getDuplexConnection().sendErrorAndClose(error); + } + } + + private void handleStreamZero(FrameType type, ByteBuf frame) { + switch (type) { + case ERROR: + tryTerminateOnZeroError(frame); + break; + case LEASE: + requesterLeaseTracker.handleLeaseFrame(frame); + break; + case KEEPALIVE: + if (keepAliveFramesAcceptor != null) { + keepAliveFramesAcceptor.receive(frame); + } + break; + default: + // Ignore unknown frames. Throwing an error will close the socket. + if (LOGGER.isInfoEnabled()) { + LOGGER.info("Requester received unsupported frame on stream 0: " + frame.toString()); + } + } + } + + private void handleFrame(int streamId, FrameType type, ByteBuf frame) { + FrameHandler receiver = this.get(streamId); + if (receiver == null) { + handleMissingResponseProcessor(streamId, type, frame); + return; + } + + switch (type) { + case NEXT_COMPLETE: + receiver.handleNext(frame, false, true); + break; + case NEXT: + boolean hasFollows = FrameHeaderCodec.hasFollows(frame); + receiver.handleNext(frame, hasFollows, false); + break; + case COMPLETE: + receiver.handleComplete(); + break; + case ERROR: + receiver.handleError(Exceptions.from(streamId, frame)); + break; + case CANCEL: + receiver.handleCancel(); + break; + case REQUEST_N: + long n = RequestNFrameCodec.requestN(frame); + receiver.handleRequestN(n); + break; + default: + throw new IllegalStateException( + "Requester received unsupported frame on stream " + streamId + ": " + frame.toString()); + } + } + + @SuppressWarnings("ConstantConditions") + private void handleMissingResponseProcessor(int streamId, FrameType type, ByteBuf frame) { + if (!super.streamIdSupplier.isBeforeOrCurrent(streamId)) { + if (type == FrameType.ERROR) { + // message for stream that has never existed, we have a problem with + // the overall connection and must tear down + String errorMessage = ErrorFrameCodec.dataUtf8(frame); + + throw new IllegalStateException( + "Client received error for non-existent stream: " + + streamId + + " Message: " + + errorMessage); + } else { + throw new IllegalStateException( + "Client received message for non-existent stream: " + + streamId + + ", frame type: " + + type); + } + } + // receiving a frame after a given stream has been cancelled/completed, + // so ignore (cancellation is async so there is a race condition) + } + + private void tryTerminateOnKeepAlive(KeepAliveSupport.KeepAlive keepAlive) { + tryTerminate( + () -> + new ConnectionErrorException( + String.format("No keep-alive acks for %d ms", keepAlive.getTimeout().toMillis()))); + getDuplexConnection().dispose(); + } + + private void tryShutdown(Throwable e) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("trying to close requester " + getDuplexConnection()); + } + if (terminationError == null) { + if (TERMINATION_ERROR.compareAndSet(this, null, e)) { + terminate(CLOSED_CHANNEL_EXCEPTION); + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); + } + } + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.info( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); + } + } + } + + private void tryTerminateOnZeroError(ByteBuf errorFrame) { + tryTerminate(() -> Exceptions.from(0, errorFrame)); + } + + private void tryTerminate(Supplier errorSupplier) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("trying to close requester " + getDuplexConnection()); + } + if (terminationError == null) { + Throwable e = errorSupplier.get(); + if (TERMINATION_ERROR.compareAndSet(this, null, e)) { + terminate(e); + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); + } + } + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); + } + } + } + + private void tryShutdown() { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("trying to close requester " + getDuplexConnection()); + } + if (terminationError == null) { + if (TERMINATION_ERROR.compareAndSet(this, null, CLOSED_CHANNEL_EXCEPTION)) { + terminate(CLOSED_CHANNEL_EXCEPTION); + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); + } + } + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); + } + } + } + + private void terminate(Throwable e) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("closing requester " + getDuplexConnection() + " due to " + e); + } + if (keepAliveFramesAcceptor != null) { + keepAliveFramesAcceptor.dispose(); + } + final RequestInterceptor requestInterceptor = getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.dispose(); + } + + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + if (requesterLeaseTracker != null) { + requesterLeaseTracker.dispose(e); + } + + final Collection activeStreamsCopy; + synchronized (this) { + final IntObjectMap activeStreams = this.activeStreams; + activeStreamsCopy = new ArrayList<>(activeStreams.values()); + } + + for (FrameHandler handler : activeStreamsCopy) { + if (handler != null) { + try { + handler.handleError(e); + } catch (Throwable ignored) { + } + } + } + + if (e == CLOSED_CHANNEL_EXCEPTION) { + onThisSideClosedSink.tryEmitEmpty(); + } else { + onThisSideClosedSink.tryEmitError(e); + } + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("requester closed " + getDuplexConnection()); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java new file mode 100644 index 000000000..50c5ba54c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java @@ -0,0 +1,477 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.util.collection.IntObjectMap; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.nio.channels.ClosedChannelException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Function; +import java.util.function.Supplier; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; + +/** Responder side of RSocket. Receives {@link ByteBuf}s from a peer's {@link RSocketRequester} */ +class RSocketResponder extends RequesterResponderSupport implements RSocket { + + private static final Logger LOGGER = LoggerFactory.getLogger(RSocketResponder.class); + + private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException(); + + private final RSocket requestHandler; + private final Sinks.Empty onThisSideClosedSink; + + @Nullable private final ResponderLeaseTracker leaseHandler; + + private volatile Throwable terminationError; + private static final AtomicReferenceFieldUpdater TERMINATION_ERROR = + AtomicReferenceFieldUpdater.newUpdater( + RSocketResponder.class, Throwable.class, "terminationError"); + + RSocketResponder( + DuplexConnection connection, + RSocket requestHandler, + PayloadDecoder payloadDecoder, + @Nullable ResponderLeaseTracker leaseHandler, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize, + Function requestInterceptorFunction, + Sinks.Empty onThisSideClosedSink) { + super( + mtu, + maxFrameLength, + maxInboundPayloadSize, + payloadDecoder, + connection, + null, + requestInterceptorFunction); + + this.requestHandler = requestHandler; + + this.leaseHandler = leaseHandler; + this.onThisSideClosedSink = onThisSideClosedSink; + + connection + .onClose() + .subscribe(null, this::tryTerminateOnConnectionError, this::tryTerminateOnConnectionClose); + + connection.receive().subscribe(this::handleFrame, e -> {}); + } + + private void tryTerminateOnConnectionError(Throwable e) { + if (LOGGER.isDebugEnabled()) { + + LOGGER.debug("Try terminate connection on responder side"); + } + tryTerminate(() -> e); + } + + private void tryTerminateOnConnectionClose() { + if (LOGGER.isDebugEnabled()) { + LOGGER.info("Try terminate connection on responder side"); + } + tryTerminate(() -> CLOSED_CHANNEL_EXCEPTION); + } + + private void tryTerminate(Supplier errorSupplier) { + if (terminationError == null) { + Throwable e = errorSupplier.get(); + if (TERMINATION_ERROR.compareAndSet(this, null, e)) { + doOnDispose(); + } + } + } + + @Override + public Mono fireAndForget(Payload payload) { + try { + return requestHandler.fireAndForget(payload); + } catch (Throwable t) { + return Mono.error(t); + } + } + + @Override + public Mono requestResponse(Payload payload) { + try { + return requestHandler.requestResponse(payload); + } catch (Throwable t) { + return Mono.error(t); + } + } + + @Override + public Flux requestStream(Payload payload) { + try { + return requestHandler.requestStream(payload); + } catch (Throwable t) { + return Flux.error(t); + } + } + + @Override + public Flux requestChannel(Publisher payloads) { + try { + return requestHandler.requestChannel(payloads); + } catch (Throwable t) { + return Flux.error(t); + } + } + + @Override + public Mono metadataPush(Payload payload) { + try { + return requestHandler.metadataPush(payload); + } catch (Throwable t) { + return Mono.error(t); + } + } + + @Override + public void dispose() { + tryTerminate(() -> new CancellationException("Disposed")); + } + + @Override + public boolean isDisposed() { + return getDuplexConnection().isDisposed(); + } + + @Override + public Mono onClose() { + return getDuplexConnection().onClose(); + } + + final void doOnDispose() { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("closing responder " + getDuplexConnection()); + } + cleanUpSendingSubscriptions(); + + getDuplexConnection().dispose(); + final RequestInterceptor requestInterceptor = getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.dispose(); + } + + final ResponderLeaseTracker handler = leaseHandler; + if (handler != null) { + handler.dispose(); + } + + requestHandler.dispose(); + onThisSideClosedSink.tryEmitEmpty(); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("responder closed " + getDuplexConnection()); + } + } + + private void cleanUpSendingSubscriptions() { + final Collection activeStreamsCopy; + synchronized (this) { + final IntObjectMap activeStreams = this.activeStreams; + activeStreamsCopy = new ArrayList<>(activeStreams.values()); + } + + for (FrameHandler handler : activeStreamsCopy) { + if (handler != null) { + handler.handleCancel(); + } + } + } + + final void handleFrame(ByteBuf frame) { + try { + int streamId = FrameHeaderCodec.streamId(frame); + FrameHandler receiver; + FrameType frameType = FrameHeaderCodec.frameType(frame); + switch (frameType) { + case REQUEST_FNF: + handleFireAndForget(streamId, frame); + break; + case REQUEST_RESPONSE: + handleRequestResponse(streamId, frame); + break; + case REQUEST_STREAM: + long streamInitialRequestN = RequestStreamFrameCodec.initialRequestN(frame); + handleStream(streamId, frame, streamInitialRequestN); + break; + case REQUEST_CHANNEL: + long channelInitialRequestN = RequestChannelFrameCodec.initialRequestN(frame); + handleChannel( + streamId, frame, channelInitialRequestN, FrameHeaderCodec.hasComplete(frame)); + break; + case METADATA_PUSH: + handleMetadataPush(metadataPush(super.getPayloadDecoder().apply(frame))); + break; + case CANCEL: + receiver = super.get(streamId); + if (receiver != null) { + receiver.handleCancel(); + } + break; + case REQUEST_N: + receiver = super.get(streamId); + if (receiver != null) { + long n = RequestNFrameCodec.requestN(frame); + receiver.handleRequestN(n); + } + break; + case PAYLOAD: + // TODO: Hook in receiving socket. + break; + case NEXT: + receiver = super.get(streamId); + if (receiver != null) { + boolean hasFollows = FrameHeaderCodec.hasFollows(frame); + receiver.handleNext(frame, hasFollows, false); + } + break; + case COMPLETE: + receiver = super.get(streamId); + if (receiver != null) { + receiver.handleComplete(); + } + break; + case ERROR: + receiver = super.get(streamId); + if (receiver != null) { + receiver.handleError(io.rsocket.exceptions.Exceptions.from(streamId, frame)); + } + break; + case NEXT_COMPLETE: + receiver = super.get(streamId); + if (receiver != null) { + receiver.handleNext(frame, false, true); + } + break; + case SETUP: + getDuplexConnection() + .sendFrame( + streamId, + ErrorFrameCodec.encode( + super.getAllocator(), + streamId, + new IllegalStateException("Setup frame received post setup."))); + break; + case LEASE: + default: + getDuplexConnection() + .sendFrame( + streamId, + ErrorFrameCodec.encode( + super.getAllocator(), + streamId, + new IllegalStateException( + "ServerRSocket: Unexpected frame type: " + frameType))); + break; + } + } catch (Throwable t) { + LOGGER.error("Unexpected error during frame handling", t); + getDuplexConnection() + .sendFrame( + 0, + ErrorFrameCodec.encode( + super.getAllocator(), + 0, + new ConnectionErrorException("Unexpected error during frame handling", t))); + this.tryTerminateOnConnectionError(t); + } + } + + final void handleFireAndForget(int streamId, ByteBuf frame) { + ResponderLeaseTracker leaseHandler = this.leaseHandler; + Throwable leaseError; + if (leaseHandler == null || (leaseError = leaseHandler.use()) == null) { + if (FrameHeaderCodec.hasFollows(frame)) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_FNF, RequestFireAndForgetFrameCodec.metadata(frame)); + } + + FireAndForgetResponderSubscriber subscriber = + new FireAndForgetResponderSubscriber(streamId, frame, this, this); + + this.add(streamId, subscriber); + } else { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_FNF, RequestFireAndForgetFrameCodec.metadata(frame)); + + fireAndForget(super.getPayloadDecoder().apply(frame)) + .subscribe(new FireAndForgetResponderSubscriber(streamId, this)); + } else { + fireAndForget(super.getPayloadDecoder().apply(frame)) + .subscribe(FireAndForgetResponderSubscriber.INSTANCE); + } + } + } else { + final RequestInterceptor requestTracker = this.getRequestInterceptor(); + if (requestTracker != null) { + requestTracker.onReject( + leaseError, FrameType.REQUEST_FNF, RequestFireAndForgetFrameCodec.metadata(frame)); + } + } + } + + final void handleRequestResponse(int streamId, ByteBuf frame) { + ResponderLeaseTracker leaseHandler = this.leaseHandler; + Throwable leaseError; + if (leaseHandler == null || (leaseError = leaseHandler.use()) == null) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_RESPONSE, RequestResponseFrameCodec.metadata(frame)); + } + + if (FrameHeaderCodec.hasFollows(frame)) { + RequestResponseResponderSubscriber subscriber = + new RequestResponseResponderSubscriber(streamId, frame, this, this); + + this.add(streamId, subscriber); + } else { + RequestResponseResponderSubscriber subscriber = + new RequestResponseResponderSubscriber(streamId, this); + + if (this.add(streamId, subscriber)) { + this.requestResponse(super.getPayloadDecoder().apply(frame)).subscribe(subscriber); + } + } + } else { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onReject( + leaseError, FrameType.REQUEST_RESPONSE, RequestResponseFrameCodec.metadata(frame)); + } + sendLeaseRejection(streamId, leaseError); + } + } + + final void handleStream(int streamId, ByteBuf frame, long initialRequestN) { + ResponderLeaseTracker leaseHandler = this.leaseHandler; + Throwable leaseError; + if (leaseHandler == null || (leaseError = leaseHandler.use()) == null) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_STREAM, RequestStreamFrameCodec.metadata(frame)); + } + + if (FrameHeaderCodec.hasFollows(frame)) { + RequestStreamResponderSubscriber subscriber = + new RequestStreamResponderSubscriber(streamId, initialRequestN, frame, this, this); + + this.add(streamId, subscriber); + } else { + RequestStreamResponderSubscriber subscriber = + new RequestStreamResponderSubscriber(streamId, initialRequestN, this); + + if (this.add(streamId, subscriber)) { + this.requestStream(super.getPayloadDecoder().apply(frame)).subscribe(subscriber); + } + } + } else { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onReject( + leaseError, FrameType.REQUEST_STREAM, RequestStreamFrameCodec.metadata(frame)); + } + sendLeaseRejection(streamId, leaseError); + } + } + + final void handleChannel(int streamId, ByteBuf frame, long initialRequestN, boolean complete) { + ResponderLeaseTracker leaseHandler = this.leaseHandler; + Throwable leaseError; + if (leaseHandler == null || (leaseError = leaseHandler.use()) == null) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_CHANNEL, RequestChannelFrameCodec.metadata(frame)); + } + + if (FrameHeaderCodec.hasFollows(frame)) { + RequestChannelResponderSubscriber subscriber = + new RequestChannelResponderSubscriber(streamId, initialRequestN, frame, this, this); + + this.add(streamId, subscriber); + } else { + final Payload firstPayload = super.getPayloadDecoder().apply(frame); + RequestChannelResponderSubscriber subscriber = + new RequestChannelResponderSubscriber(streamId, initialRequestN, firstPayload, this); + + if (this.add(streamId, subscriber)) { + this.requestChannel(subscriber).subscribe(subscriber); + if (complete) { + subscriber.handleComplete(); + } + } + } + } else { + final RequestInterceptor requestTracker = this.getRequestInterceptor(); + if (requestTracker != null) { + requestTracker.onReject( + leaseError, FrameType.REQUEST_CHANNEL, RequestChannelFrameCodec.metadata(frame)); + } + sendLeaseRejection(streamId, leaseError); + } + } + + private void sendLeaseRejection(int streamId, Throwable leaseError) { + getDuplexConnection() + .sendFrame(streamId, ErrorFrameCodec.encode(getAllocator(), streamId, leaseError)); + } + + private void handleMetadataPush(Mono result) { + result.subscribe(MetadataPushResponderSubscriber.INSTANCE); + } + + @Override + public boolean add(int streamId, FrameHandler frameHandler) { + if (!super.add(streamId, frameHandler)) { + frameHandler.handleCancel(); + return false; + } + + return true; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java new file mode 100644 index 000000000..e969c39d2 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java @@ -0,0 +1,523 @@ +/* + * Copyright 2015-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.assertMtu; +import static io.rsocket.core.PayloadValidationUtils.assertValidateSetup; +import static io.rsocket.core.ReassemblyUtils.assertInboundPayloadSize; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Closeable; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.RSocketErrorException; +import io.rsocket.SocketAcceptor; +import io.rsocket.exceptions.InvalidSetupException; +import io.rsocket.exceptions.RejectedSetupException; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.lease.TrackingLeaseSender; +import io.rsocket.plugins.DuplexConnectionInterceptor; +import io.rsocket.plugins.InitializingInterceptorRegistry; +import io.rsocket.plugins.InterceptorRegistry; +import io.rsocket.plugins.RequestInterceptor; +import io.rsocket.resume.SessionManager; +import io.rsocket.transport.ServerTransport; +import java.time.Duration; +import java.util.Objects; +import java.util.function.Consumer; +import java.util.function.Supplier; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +/** + * The main class for starting an RSocket server. + * + *

For example: + * + *

{@code
+ * CloseableChannel closeable =
+ *         RSocketServer.create(SocketAcceptor.with(new RSocket() {...}))
+ *                 .bind(TcpServerTransport.create("localhost", 7000))
+ *                 .block();
+ * }
+ */ +public final class RSocketServer { + private static final String SERVER_TAG = "server"; + + private SocketAcceptor acceptor = SocketAcceptor.with(new RSocket() {}); + private InitializingInterceptorRegistry interceptors = new InitializingInterceptorRegistry(); + + private Resume resume; + private Consumer leaseConfigurer = null; + + private int mtu = 0; + private int maxInboundPayloadSize = Integer.MAX_VALUE; + private PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; + private Duration timeout = Duration.ofMinutes(1); + + private RSocketServer() {} + + /** Static factory method to create an {@code RSocketServer}. */ + public static RSocketServer create() { + return new RSocketServer(); + } + + /** + * Static factory method to create an {@code RSocketServer} instance with the given {@code + * SocketAcceptor}. Effectively a shortcut for: + * + *
+   * RSocketServer.create().acceptor(...);
+   * 
+ * + * @param acceptor the acceptor to handle connections with + * @return the same instance for method chaining + * @see #acceptor(SocketAcceptor) + */ + public static RSocketServer create(SocketAcceptor acceptor) { + return RSocketServer.create().acceptor(acceptor); + } + + /** + * Set the acceptor to handle incoming connections and handle requests. + * + *

An example with access to the {@code SETUP} frame and sending RSocket for performing + * requests back to the client if needed: + * + *

{@code
+   * RSocketServer.create((setup, sendingRSocket) -> Mono.just(new RSocket() {...}))
+   *         .bind(TcpServerTransport.create("localhost", 7000))
+   *         .subscribe();
+   * }
+ * + *

A shortcut to provide the handling RSocket only: + * + *

{@code
+   * RSocketServer.create(SocketAcceptor.with(new RSocket() {...}))
+   *         .bind(TcpServerTransport.create("localhost", 7000))
+   *         .subscribe();
+   * }
+ * + *

A shortcut to handle request-response interactions only: + * + *

{@code
+   * RSocketServer.create(SocketAcceptor.forRequestResponse(payload -> ...))
+   *         .bind(TcpServerTransport.create("localhost", 7000))
+   *         .subscribe();
+   * }
+ * + *

By default, {@code new RSocket(){}} is used for handling which rejects requests from the + * client with {@link UnsupportedOperationException}. + * + * @param acceptor the acceptor to handle incoming connections and requests with + * @return the same instance for method chaining + */ + public RSocketServer acceptor(SocketAcceptor acceptor) { + Objects.requireNonNull(acceptor); + this.acceptor = acceptor; + return this; + } + + /** + * Configure interception at one of the following levels: + * + *

    + *
  • Transport level + *
  • At the level of accepting new connections + *
  • Performing requests + *
  • Responding to requests + *
+ * + * @param configurer a configurer to customize interception with. + * @return the same instance for method chaining + * @see io.rsocket.plugins.LimitRateInterceptor + */ + public RSocketServer interceptors(Consumer configurer) { + configurer.accept(this.interceptors); + return this; + } + + /** + * Enables the Resume capability of the RSocket protocol where if the client gets disconnected, + * the connection is re-acquired and any interrupted streams are transparently resumed. For this + * to work clients must also support and request to enable this when connecting. + * + *

Use the {@link Resume} argument to customize the Resume session duration, storage, retry + * logic, and others. + * + *

By default this is not enabled. + * + * @param resume configuration for the Resume capability + * @return the same instance for method chaining + * @see Resuming + * Operation + */ + public RSocketServer resume(Resume resume) { + this.resume = resume; + return this; + } + + /** + * Enables the Lease feature of the RSocket protocol where the number of requests that can be + * performed from either side are rationed via {@code LEASE} frames from the responder side. For + * this to work clients must also support and request to enable this when connecting. + * + *

Example usage: + * + *

{@code
+   * RSocketServer.create(SocketAcceptor.with(new RSocket() {...}))
+   *         .lease(spec ->
+   *            spec.sender(() -> Flux.interval(ofSeconds(1))
+   *                                  .map(__ -> Lease.create(ofSeconds(1), 1)))
+   *         )
+   *         .bind(TcpServerTransport.create("localhost", 7000))
+   *         .subscribe();
+   * }
+ * + *

By default this is not enabled. + * + * @param leaseConfigurer consumer which accepts {@link LeaseSpec} and use it for configuring + * @return the same instance for method chaining + * @see Lease + * Semantics + */ + public RSocketServer lease(Consumer leaseConfigurer) { + this.leaseConfigurer = leaseConfigurer; + return this; + } + + /** + * When this is set, frames reassembler control maximum payload size which can be reassembled. + * + *

By default this is not set in which case maximum reassembled payloads size is not + * controlled. + * + * @param maxInboundPayloadSize the threshold size for reassembly, must no be less than 64 bytes. + * Please note, {@code maxInboundPayloadSize} must always be greater or equal to {@link + * io.rsocket.transport.Transport#maxFrameLength()}, otherwise inbound frame can exceed the + * {@code maxInboundPayloadSize} + * @return the same instance for method chaining + * @see Fragmentation + * and Reassembly + */ + public RSocketServer maxInboundPayloadSize(int maxInboundPayloadSize) { + this.maxInboundPayloadSize = assertInboundPayloadSize(maxInboundPayloadSize); + return this; + } + + /** + * Specify the max time to wait for the first frame (e.g. {@code SETUP}) on an accepted + * connection. + * + *

By default this is set to 1 minute. + * + * @param timeout duration + * @return the same instance for method chaining + */ + public RSocketServer maxTimeToFirstFrame(Duration timeout) { + if (timeout.isNegative() || timeout.isZero()) { + throw new IllegalArgumentException("Setup Handling Timeout should be greater than zero"); + } + this.timeout = timeout; + return this; + } + + /** + * When this is set, frames larger than the given maximum transmission unit (mtu) size value are + * fragmented. + * + *

By default this is not set in which case payloads are sent whole up to the maximum frame + * size of 16,777,215 bytes. + * + * @param mtu the threshold size for fragmentation, must be no less than 64 + * @return the same instance for method chaining + * @see Fragmentation + * and Reassembly + */ + public RSocketServer fragment(int mtu) { + this.mtu = assertMtu(mtu); + return this; + } + + /** + * Configure the {@code PayloadDecoder} used to create {@link Payload}'s from incoming raw frame + * buffers. The following decoders are available: + * + *

    + *
  • {@link PayloadDecoder#DEFAULT} -- the data and metadata are independent copies of the + * underlying frame {@link ByteBuf} + *
  • {@link PayloadDecoder#ZERO_COPY} -- the data and metadata are retained slices of the + * underlying {@link ByteBuf}. That's more efficient but requires careful tracking and + * {@link Payload#release() release} of the payload when no longer needed. + *
+ * + *

By default this is set to {@link PayloadDecoder#DEFAULT} in which case data and metadata are + * copied and do not need to be tracked and released. + * + * @param decoder the decoder to use + * @return the same instance for method chaining + */ + public RSocketServer payloadDecoder(PayloadDecoder decoder) { + Objects.requireNonNull(decoder); + this.payloadDecoder = decoder; + return this; + } + + /** + * Start the server on the given transport. + * + *

The following transports are available from additional RSocket Java modules: + * + *

    + *
  • {@link io.rsocket.transport.netty.client.TcpServerTransport TcpServerTransport} via + * {@code rsocket-transport-netty}. + *
  • {@link io.rsocket.transport.netty.client.WebsocketServerTransport + * WebsocketServerTransport} via {@code rsocket-transport-netty}. + *
  • {@link io.rsocket.transport.local.LocalServerTransport LocalServerTransport} via {@code + * rsocket-transport-local} + *
+ * + * @param transport the transport of choice to connect with + * @param the type of {@code Closeable} for the given transport + * @return a {@code Mono} with a {@code Closeable} that can be used to obtain information about + * the server, stop it, or be notified of when it is stopped. + */ + public Mono bind(ServerTransport transport) { + return Mono.defer( + new Supplier>() { + final ServerSetup serverSetup = serverSetup(timeout); + + @Override + public Mono get() { + int maxFrameLength = transport.maxFrameLength(); + assertValidateSetup(maxFrameLength, maxInboundPayloadSize, mtu); + return transport + .start(duplexConnection -> acceptor(serverSetup, duplexConnection, maxFrameLength)) + .doOnNext(c -> c.onClose().doFinally(v -> serverSetup.dispose()).subscribe()); + } + }); + } + + /** + * Start the server on the given transport. Effectively is a shortcut for {@code + * .bind(ServerTransport).block()} + */ + public T bindNow(ServerTransport transport) { + return bind(transport).block(); + } + /** + * An alternative to {@link #bind(ServerTransport)} that is useful for installing RSocket on a + * server that is started independently. + * + * @see io.rsocket.examples.transport.ws.WebSocketHeadersSample + */ + public ServerTransport.ConnectionAcceptor asConnectionAcceptor() { + return asConnectionAcceptor(FRAME_LENGTH_MASK); + } + + /** + * An alternative to {@link #bind(ServerTransport)} that is useful for installing RSocket on a + * server that is started independently. + * + * @see io.rsocket.examples.transport.ws.WebSocketHeadersSample + */ + public ServerTransport.ConnectionAcceptor asConnectionAcceptor(int maxFrameLength) { + assertValidateSetup(maxFrameLength, maxInboundPayloadSize, mtu); + return new ServerTransport.ConnectionAcceptor() { + private final ServerSetup serverSetup = serverSetup(timeout); + + @Override + public Mono apply(DuplexConnection connection) { + return acceptor(serverSetup, connection, maxFrameLength); + } + }; + } + + private Mono acceptor( + ServerSetup serverSetup, DuplexConnection sourceConnection, int maxFrameLength) { + + final DuplexConnection interceptedConnection = + interceptors.initConnection(DuplexConnectionInterceptor.Type.SOURCE, sourceConnection); + + return serverSetup + .init(LoggingDuplexConnection.wrapIfEnabled(interceptedConnection)) + .flatMap( + tuple2 -> { + final ByteBuf startFrame = tuple2.getT1(); + final DuplexConnection clientServerConnection = tuple2.getT2(); + + return accept(serverSetup, startFrame, clientServerConnection, maxFrameLength); + }); + } + + private Mono acceptResume( + ServerSetup serverSetup, ByteBuf resumeFrame, DuplexConnection clientServerConnection) { + return serverSetup.acceptRSocketResume(resumeFrame, clientServerConnection); + } + + private Mono accept( + ServerSetup serverSetup, + ByteBuf startFrame, + DuplexConnection clientServerConnection, + int maxFrameLength) { + switch (FrameHeaderCodec.frameType(startFrame)) { + case SETUP: + return acceptSetup(serverSetup, startFrame, clientServerConnection, maxFrameLength); + case RESUME: + return acceptResume(serverSetup, startFrame, clientServerConnection); + default: + serverSetup.sendError( + clientServerConnection, + new InvalidSetupException("SETUP or RESUME frame must be received before any others")); + return clientServerConnection.onClose(); + } + } + + private Mono acceptSetup( + ServerSetup serverSetup, + ByteBuf setupFrame, + DuplexConnection clientServerConnection, + int maxFrameLength) { + + if (!SetupFrameCodec.isSupportedVersion(setupFrame)) { + serverSetup.sendError( + clientServerConnection, + new InvalidSetupException( + "Unsupported version: " + SetupFrameCodec.humanReadableVersion(setupFrame))); + return clientServerConnection.onClose(); + } + + boolean leaseEnabled = leaseConfigurer != null; + if (SetupFrameCodec.honorLease(setupFrame) && !leaseEnabled) { + serverSetup.sendError( + clientServerConnection, new InvalidSetupException("lease is not supported")); + return clientServerConnection.onClose(); + } + + return serverSetup.acceptRSocketSetup( + setupFrame, + clientServerConnection, + (keepAliveHandler, wrappedDuplexConnection) -> { + ConnectionSetupPayload setupPayload = + new DefaultConnectionSetupPayload(setupFrame.retain()); + final InitializingInterceptorRegistry interceptors = this.interceptors; + final ClientServerInputMultiplexer multiplexer = + new ClientServerInputMultiplexer(wrappedDuplexConnection, interceptors, false); + + final LeaseSpec leases; + final RequesterLeaseTracker requesterLeaseTracker; + if (leaseEnabled) { + leases = new LeaseSpec(); + leaseConfigurer.accept(leases); + requesterLeaseTracker = + new RequesterLeaseTracker(SERVER_TAG, leases.maxPendingRequests); + } else { + leases = null; + requesterLeaseTracker = null; + } + + final Sinks.Empty requesterOnAllClosedSink = Sinks.unsafe().empty(); + final Sinks.Empty responderOnAllClosedSink = Sinks.unsafe().empty(); + + RSocket rSocketRequester = + new RSocketRequester( + multiplexer.asServerConnection(), + payloadDecoder, + StreamIdSupplier.serverSupplier(), + mtu, + maxFrameLength, + maxInboundPayloadSize, + setupPayload.keepAliveInterval(), + setupPayload.keepAliveMaxLifetime(), + keepAliveHandler, + interceptors::initRequesterRequestInterceptor, + requesterLeaseTracker, + requesterOnAllClosedSink, + Mono.whenDelayError( + responderOnAllClosedSink.asMono(), requesterOnAllClosedSink.asMono())); + + RSocket wrappedRSocketRequester = interceptors.initRequester(rSocketRequester); + + return interceptors + .initSocketAcceptor(acceptor) + .accept(setupPayload, wrappedRSocketRequester) + .onErrorResume( + err -> + Mono.fromRunnable( + () -> + serverSetup.sendError( + wrappedDuplexConnection, rejectedSetupError(err))) + .then(wrappedDuplexConnection.onClose()) + .then(Mono.error(err))) + .doOnNext( + rSocketHandler -> { + RSocket wrappedRSocketHandler = interceptors.initResponder(rSocketHandler); + DuplexConnection clientConnection = multiplexer.asClientConnection(); + + ResponderLeaseTracker responderLeaseTracker = + leaseEnabled + ? new ResponderLeaseTracker(SERVER_TAG, clientConnection, leases.sender) + : null; + + RSocket rSocketResponder = + new RSocketResponder( + clientConnection, + wrappedRSocketHandler, + payloadDecoder, + responderLeaseTracker, + mtu, + maxFrameLength, + maxInboundPayloadSize, + leaseEnabled && leases.sender instanceof TrackingLeaseSender + ? rSocket -> + interceptors.initResponderRequestInterceptor( + rSocket, (RequestInterceptor) leases.sender) + : interceptors::initResponderRequestInterceptor, + responderOnAllClosedSink); + }) + .doFinally(signalType -> setupPayload.release()) + .then(); + }); + } + + private ServerSetup serverSetup(Duration timeout) { + return resume != null ? createSetup(timeout) : new ServerSetup.DefaultServerSetup(timeout); + } + + ServerSetup createSetup(Duration timeout) { + return new ServerSetup.ResumableServerSetup( + timeout, + new SessionManager(), + resume.getSessionDuration(), + resume.getStreamTimeout(), + resume.getStoreFactory(SERVER_TAG), + resume.isCleanupStoreOnKeepAlive()); + } + + private RSocketErrorException rejectedSetupError(Throwable err) { + String msg = err.getMessage(); + return new RejectedSetupException(msg == null ? "rejected by server acceptor" : msg); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ReassemblyUtils.java b/rsocket-core/src/main/java/io/rsocket/core/ReassemblyUtils.java new file mode 100644 index 000000000..8e084fe9c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ReassemblyUtils.java @@ -0,0 +1,247 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.MIN_MTU_SIZE; +import static io.rsocket.core.StateUtils.isReassembling; +import static io.rsocket.core.StateUtils.isTerminated; +import static io.rsocket.core.StateUtils.markReassembled; +import static io.rsocket.core.StateUtils.markReassembling; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.decoder.PayloadDecoder; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; + +class ReassemblyUtils { + static final String ILLEGAL_REASSEMBLED_PAYLOAD_SIZE = + "Reassembled payload size went out of allowed %s bytes"; + + @SuppressWarnings("ConstantConditions") + static void release(RequesterFrameHandler framesHolder, long state) { + if (isReassembling(state)) { + final CompositeByteBuf frames = framesHolder.getFrames(); + framesHolder.setFrames(null); + frames.release(); + } + } + + @SuppressWarnings({"ConstantConditions", "SynchronizationOnLocalVariableOrMethodParameter"}) + static void synchronizedRelease(RequesterFrameHandler framesHolder, long state) { + if (isReassembling(state)) { + final CompositeByteBuf frames = framesHolder.getFrames(); + framesHolder.setFrames(null); + + synchronized (frames) { + frames.release(); + } + } + } + + static void handleNextSupport( + AtomicLongFieldUpdater updater, + T instance, + Subscription subscription, + CoreSubscriber inboundSubscriber, + PayloadDecoder payloadDecoder, + ByteBufAllocator allocator, + int maxInboundPayloadSize, + ByteBuf frame, + boolean hasFollows, + boolean isLastPayload) { + + long state = updater.get(instance); + if (isTerminated(state)) { + return; + } + + if (!hasFollows && !isReassembling(state)) { + Payload payload; + try { + payload = payloadDecoder.apply(frame); + } catch (Throwable t) { + // sends cancel frame to prevent any further frames + subscription.cancel(); + // terminates downstream + inboundSubscriber.onError(t); + + return; + } + + instance.handlePayload(payload); + if (isLastPayload) { + instance.handleComplete(); + } + return; + } + + CompositeByteBuf frames = instance.getFrames(); + if (frames == null) { + frames = + ReassemblyUtils.addFollowingFrame( + allocator.compositeBuffer(), frame, hasFollows, maxInboundPayloadSize); + instance.setFrames(frames); + + long previousState = markReassembling(updater, instance); + if (isTerminated(previousState)) { + instance.setFrames(null); + frames.release(); + return; + } + } else { + try { + frames = + ReassemblyUtils.addFollowingFrame(frames, frame, hasFollows, maxInboundPayloadSize); + } catch (IllegalStateException t) { + if (isTerminated(updater.get(instance))) { + return; + } + + // sends cancel frame to prevent any further frames + subscription.cancel(); + // terminates downstream + inboundSubscriber.onError(t); + + return; + } + } + + if (!hasFollows) { + long previousState = markReassembled(updater, instance); + if (isTerminated(previousState)) { + return; + } + + instance.setFrames(null); + + Payload payload; + try { + payload = payloadDecoder.apply(frames); + frames.release(); + } catch (Throwable t) { + ReferenceCountUtil.safeRelease(frames); + + // sends cancel frame to prevent any further frames + subscription.cancel(); + // terminates downstream + inboundSubscriber.onError(t); + + return; + } + + instance.handlePayload(payload); + + if (isLastPayload) { + instance.handleComplete(); + } + } + } + + static CompositeByteBuf addFollowingFrame( + CompositeByteBuf frames, + ByteBuf followingFrame, + boolean hasFollows, + int maxInboundPayloadSize) { + int readableBytes = frames.readableBytes(); + if (readableBytes == 0) { + return frames.addComponent(true, followingFrame.retain()); + } else if (maxInboundPayloadSize != Integer.MAX_VALUE + && readableBytes + followingFrame.readableBytes() - FrameHeaderCodec.size() + > maxInboundPayloadSize) { + throw new IllegalStateException( + String.format(ILLEGAL_REASSEMBLED_PAYLOAD_SIZE, maxInboundPayloadSize)); + } else if (followingFrame.readableBytes() < MIN_MTU_SIZE - 3 && hasFollows) { + // FIXME: check MIN_MTU_SIZE only (currently fragments have size of 61) + throw new IllegalStateException("Fragment is too small."); + } + + final boolean hasMetadata = FrameHeaderCodec.hasMetadata(followingFrame); + + // skip headers + followingFrame.skipBytes(FrameHeaderCodec.size()); + + // if has metadata, then we have to increase metadata length in containing frames + // CompositeByteBuf + if (hasMetadata) { + final FrameType frameType = FrameHeaderCodec.frameType(frames); + final int lengthFieldPosition = + FrameHeaderCodec.size() + (frameType.hasInitialRequestN() ? Integer.BYTES : 0); + + frames.markReaderIndex(); + frames.skipBytes(lengthFieldPosition); + + final int nextMetadataLength = decodeLength(frames) + decodeLength(followingFrame); + + frames.resetReaderIndex(); + + frames.markWriterIndex(); + frames.writerIndex(lengthFieldPosition); + + encodeLength(frames, nextMetadataLength); + + frames.resetWriterIndex(); + } + + synchronized (frames) { + if (frames.refCnt() > 0) { + followingFrame.retain(); + return frames.addComponent(true, followingFrame); + } else { + throw new IllegalReferenceCountException(0); + } + } + } + + private static void encodeLength(final ByteBuf byteBuf, final int length) { + if ((length & ~FRAME_LENGTH_MASK) != 0) { + throw new IllegalArgumentException("Length is larger than 24 bits"); + } + // Write each byte separately in reverse order, this mean we can write 1 << 23 without + // overflowing. + byteBuf.writeByte(length >> 16); + byteBuf.writeByte(length >> 8); + byteBuf.writeByte(length); + } + + private static int decodeLength(final ByteBuf byteBuf) { + int length = (byteBuf.readByte() & 0xFF) << 16; + length |= (byteBuf.readByte() & 0xFF) << 8; + length |= byteBuf.readByte() & 0xFF; + return length; + } + + static int assertInboundPayloadSize(int inboundPayloadSize) { + if (inboundPayloadSize < MIN_MTU_SIZE) { + String msg = + String.format( + "The min allowed inboundPayloadSize size is %d bytes, provided: %d", + FrameLengthCodec.FRAME_LENGTH_MASK, inboundPayloadSize); + throw new IllegalArgumentException(msg); + } else { + return inboundPayloadSize; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ReconnectMono.java b/rsocket-core/src/main/java/io/rsocket/core/ReconnectMono.java new file mode 100644 index 000000000..afad6e0df --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ReconnectMono.java @@ -0,0 +1,275 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import java.time.Duration; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Disposable; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +final class ReconnectMono extends Mono implements Invalidatable, Disposable, Scannable { + + final Mono source; + final BiConsumer onValueReceived; + final Consumer onValueExpired; + final ResolvingInner resolvingInner; + + ReconnectMono( + Mono source, + Consumer onValueExpired, + BiConsumer onValueReceived) { + this.source = source; + this.onValueExpired = onValueExpired; + this.onValueReceived = onValueReceived; + this.resolvingInner = new ResolvingInner<>(this); + } + + public Mono getSource() { + return source; + } + + @Override + public Object scanUnsafe(Attr key) { + if (key == Attr.PARENT) return source; + if (key == Attr.PREFETCH) return Integer.MAX_VALUE; + + final boolean isDisposed = isDisposed(); + if (key == Attr.TERMINATED) return isDisposed; + if (key == Attr.ERROR) return this.resolvingInner.t; + + return null; + } + + @Override + public void invalidate() { + this.resolvingInner.invalidate(); + } + + @Override + public void dispose() { + this.resolvingInner.terminate( + new CancellationException("ReconnectMono has already been disposed")); + } + + @Override + public boolean isDisposed() { + return this.resolvingInner.isDisposed(); + } + + @Override + @SuppressWarnings("uncheked") + public void subscribe(CoreSubscriber actual) { + final ResolvingOperator.MonoDeferredResolutionOperator inner = + new ResolvingOperator.MonoDeferredResolutionOperator<>(this.resolvingInner, actual); + actual.onSubscribe(inner); + + this.resolvingInner.observe(inner); + } + + /** + * Block the calling thread indefinitely, waiting for the completion of this {@code + * ReconnectMono}. If the {@link ReconnectMono} is completed with an error a RuntimeException that + * wraps the error is thrown. + * + * @return the value of this {@code ReconnectMono} + */ + @Override + @Nullable + public T block() { + return block(null); + } + + /** + * Block the calling thread for the specified time, waiting for the completion of this {@code + * ReconnectMono}. If the {@link ReconnectMono} is completed with an error a RuntimeException that + * wraps the error is thrown. + * + * @param timeout the timeout value as a {@link Duration} + * @return the value of this {@code ReconnectMono} or {@code null} if the timeout is reached and + * the {@code ReconnectMono} has not completed + */ + @Override + @Nullable + @SuppressWarnings("uncheked") + public T block(@Nullable Duration timeout) { + return this.resolvingInner.block(timeout); + } + + /** + * Subscriber that subscribes to the source {@link Mono} to receive its value.
+ * Note that the source is not expected to complete empty, and if this happens, execution will + * terminate with an {@code IllegalStateException}. + */ + static final class ReconnectMainSubscriber implements CoreSubscriber { + + final ResolvingInner parent; + + volatile Subscription s; + + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater( + ReconnectMainSubscriber.class, Subscription.class, "s"); + + volatile int wip; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(ReconnectMainSubscriber.class, "wip"); + + T value; + + ReconnectMainSubscriber(ResolvingInner parent) { + this.parent = parent; + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } + } + + @Override + public void onComplete() { + final Subscription s = this.s; + final T value = this.value; + + if (s == Operators.cancelledSubscription() || !S.compareAndSet(this, s, null)) { + this.doFinally(); + return; + } + + final ResolvingInner p = this.parent; + if (value == null) { + p.terminate(new IllegalStateException("Source completed empty")); + } else { + p.complete(value); + } + } + + @Override + public void onError(Throwable t) { + final Subscription s = this.s; + + if (s == Operators.cancelledSubscription() + || S.getAndSet(this, Operators.cancelledSubscription()) + == Operators.cancelledSubscription()) { + this.doFinally(); + Operators.onErrorDropped(t, Context.empty()); + return; + } + + this.doFinally(); + // terminate upstream which means retryBackoff has exhausted + this.parent.terminate(t); + } + + @Override + public void onNext(T value) { + if (this.s == Operators.cancelledSubscription()) { + this.parent.doOnValueExpired(value); + return; + } + + this.value = value; + // volatile write and check on racing + this.doFinally(); + } + + void dispose() { + if (Operators.terminate(S, this)) { + this.doFinally(); + } + } + + final void doFinally() { + if (WIP.getAndIncrement(this) != 0) { + return; + } + + int m = 1; + T value; + + for (; ; ) { + value = this.value; + if (value != null && this.s == Operators.cancelledSubscription()) { + this.value = null; + this.parent.doOnValueExpired(value); + return; + } + + m = WIP.addAndGet(this, -m); + if (m == 0) { + return; + } + } + } + } + + static final class ResolvingInner extends ResolvingOperator implements Scannable { + + final ReconnectMono parent; + final ReconnectMainSubscriber mainSubscriber; + + ResolvingInner(ReconnectMono parent) { + this.parent = parent; + this.mainSubscriber = new ReconnectMainSubscriber<>(this); + } + + @Override + protected void doOnValueExpired(T value) { + this.parent.onValueExpired.accept(value); + } + + @Override + protected void doOnValueResolved(T value) { + this.parent.onValueReceived.accept(value, this.parent); + } + + @Override + protected void doOnDispose() { + this.mainSubscriber.dispose(); + } + + @Override + protected void doSubscribe() { + this.parent.source.subscribe(this.mainSubscriber); + } + + @Override + public Object scanUnsafe(Attr key) { + if (key == Attr.PARENT) return this.parent; + return null; + } + } +} + +interface Invalidatable { + + void invalidate(); +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java new file mode 100644 index 000000000..aab491793 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java @@ -0,0 +1,829 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.ReassemblyUtils.handleNextSupport; +import static io.rsocket.core.SendUtils.DISCARD_CONTEXT; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.util.Objects; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; +import reactor.util.context.ContextView; + +final class RequestChannelRequesterFlux extends Flux + implements RequesterFrameHandler, + LeasePermitHandler, + CoreSubscriber, + Subscription, + Scannable { + + final ByteBufAllocator allocator; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + final PayloadDecoder payloadDecoder; + + final Publisher payloadsPublisher; + + @Nullable final RequesterLeaseTracker requesterLeaseTracker; + @Nullable final RequestInterceptor requestInterceptor; + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(RequestChannelRequesterFlux.class, "state"); + + int streamId; + + boolean isFirstSignal = true; + Payload firstPayload; + + Subscription outboundSubscription; + boolean outboundDone; + Throwable outboundError; + + Context cachedContext; + CoreSubscriber inboundSubscriber; + boolean inboundDone; + long requested; + long produced; + + CompositeByteBuf frames; + + RequestChannelRequesterFlux( + Publisher payloadsPublisher, RequesterResponderSupport requesterResponderSupport) { + this.allocator = requesterResponderSupport.getAllocator(); + this.payloadsPublisher = payloadsPublisher; + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requesterLeaseTracker = requesterResponderSupport.getRequesterLeaseTracker(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + Objects.requireNonNull(actual, "subscribe"); + + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + final IllegalStateException e = + new IllegalStateException("RequestChannelFlux allows only a single Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_CHANNEL, null); + } + + Operators.error(actual, e); + return; + } + + this.inboundSubscriber = actual; + this.payloadsPublisher.subscribe(this); + } + + @Override + public void onSubscribe(Subscription outboundSubscription) { + if (Operators.validate(this.outboundSubscription, outboundSubscription)) { + this.outboundSubscription = outboundSubscription; + this.inboundSubscriber.onSubscribe(this); + } + } + + @Override + public final void request(long n) { + if (!Operators.validate(n)) { + return; + } + + this.requested = Operators.addCap(this.requested, n); + + long previousState = addRequestN(STATE, this, n, this.requesterLeaseTracker == null); + if (isTerminated(previousState)) { + return; + } + + if (hasRequested(previousState)) { + if (isFirstFrameSent(previousState) + && !isMaxAllowedRequestN(extractRequestN(previousState))) { + final int streamId = this.streamId; + final ByteBuf requestNFrame = RequestNFrameCodec.encode(this.allocator, streamId, n); + this.connection.sendFrame(streamId, requestNFrame); + } + return; + } + + // do first request + this.outboundSubscription.request(1); + } + + @Override + public void onNext(Payload p) { + if (this.outboundDone) { + p.release(); + return; + } + + if (this.isFirstSignal) { + this.isFirstSignal = false; + + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + final boolean leaseEnabled = requesterLeaseTracker != null; + + if (leaseEnabled) { + this.firstPayload = p; + + final long previousState = markFirstPayloadReceived(STATE, this); + if (isTerminated(previousState)) { + this.firstPayload = null; + p.release(); + return; + } + + requesterLeaseTracker.issue(this); + } else { + final long state = this.state; + if (isTerminated(state)) { + p.release(); + return; + } + // TODO: check if source is Scalar | Callable | Mono + sendFirstPayload(p, extractRequestN(state), false); + } + } else { + sendFollowingPayload(p); + } + } + + @Override + public boolean handlePermit() { + final long previousState = markReadyToSendFirstFrame(STATE, this); + + if (isTerminated(previousState)) { + return false; + } + + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + + sendFirstPayload( + firstPayload, extractRequestN(previousState), isOutboundTerminated(previousState)); + return true; + } + + void sendFirstPayload(Payload firstPayload, long initialRequestN, boolean completed) { + int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, firstPayload, true)) { + final long previousState = markTerminated(STATE, this); + + if (isTerminated(previousState)) { + return; + } + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_CHANNEL, firstPayload.metadata()); + } + + firstPayload.release(); + + this.inboundDone = true; + this.inboundSubscriber.onError(e); + return; + } + } catch (IllegalReferenceCountException e) { + final long previousState = markTerminated(STATE, this); + + if (isTerminated(previousState)) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_CHANNEL, null); + } + + this.inboundDone = true; + this.inboundSubscriber.onError(e); + return; + } + + final RequesterResponderSupport sm = this.requesterResponderSupport; + final DuplexConnection connection = this.connection; + final ByteBufAllocator allocator = this.allocator; + + final int streamId; + try { + streamId = sm.addAndGetNextStreamId(this); + this.streamId = streamId; + } catch (Throwable t) { + final long previousState = markTerminated(STATE, this); + + firstPayload.release(); + + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_CHANNEL, firstPayload.metadata()); + } + + this.inboundDone = true; + this.inboundSubscriber.onError(ut); + + return; + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, FrameType.REQUEST_CHANNEL, firstPayload.metadata()); + } + + try { + sendReleasingPayload( + streamId, + FrameType.REQUEST_CHANNEL, + initialRequestN, + mtu, + firstPayload, + connection, + allocator, + completed); + } catch (Throwable t) { + final long previousState = markTerminated(STATE, this); + + firstPayload.release(); + + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + sm.remove(streamId, this); + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } + + this.inboundDone = true; + this.inboundSubscriber.onError(t); + return; + } + + long previousState = markFirstFrameSent(STATE, this); + if (isTerminated(previousState)) { + // now, this can be terminated in case of the following scenarios: + // + // 1) SendFirst is called synchronously from onNext, thus we can have + // handleError called before we marked first frame sent, thus we may check if + // inboundDone flag is true and exit execution without any further actions: + if (this.inboundDone) { + return; + } + + sm.remove(streamId, this); + + // 2) SendFirst is called asynchronously on the connection event-loop. Thus, we + // need to check if outbound error is present. Note, we check outboundError since + // in the last scenario, cancellation may terminate the state and async + // onComplete may set outboundDone to true. Thus, we explicitly check for + // outboundError + final Throwable outboundError = this.outboundError; + if (outboundError != null) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, streamId, outboundError); + connection.sendFrame(streamId, errorFrame); + + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, outboundError); + } + + this.inboundDone = true; + this.inboundSubscriber.onError(outboundError); + } else { + // 3) SendFirst is interleaving with cancel. Thus, we need to generate cancel + // frame + final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); + connection.sendFrame(streamId, cancelFrame); + + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_CHANNEL); + } + } + + return; + } + + if (!completed && isOutboundTerminated(previousState)) { + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); + connection.sendFrame(streamId, completeFrame); + } + + if (isMaxAllowedRequestN(initialRequestN)) { + return; + } + + long requestN = extractRequestN(previousState); + if (isMaxAllowedRequestN(requestN)) { + final ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, streamId, requestN); + connection.sendFrame(streamId, requestNFrame); + return; + } + + if (requestN > initialRequestN) { + final ByteBuf requestNFrame = + RequestNFrameCodec.encode(allocator, streamId, requestN - initialRequestN); + connection.sendFrame(streamId, requestNFrame); + } + } + + final void sendFollowingPayload(Payload followingPayload) { + int streamId = this.streamId; + int mtu = this.mtu; + + try { + if (!isValid(mtu, this.maxFrameLength, followingPayload, true)) { + followingPayload.release(); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + if (!this.tryCancel()) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + this.propagateErrorSafely(e); + return; + } + } catch (IllegalReferenceCountException e) { + if (!this.tryCancel()) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + this.propagateErrorSafely(e); + + return; + } + + try { + sendReleasingPayload( + streamId, + + // TODO: Should be a different flag in case of the scalar + // source or if we know in advance upstream is mono + FrameType.NEXT, + mtu, + followingPayload, + this.connection, + allocator, + true); + } catch (Throwable e) { + if (!this.tryCancel()) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + this.propagateErrorSafely(e); + } + } + + void propagateErrorSafely(Throwable t) { + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + if (!this.inboundDone) { + synchronized (this) { + if (!this.inboundDone) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, t); + } + + this.inboundDone = true; + this.inboundSubscriber.onError(t); + } else { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + } + } + } else { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + } + } + + @Override + public final void cancel() { + if (!tryCancel()) { + return; + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(this.streamId, FrameType.REQUEST_CHANNEL); + } + } + + boolean tryCancel() { + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return false; + } + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + if (!isReadyToSendFirstFrame(previousState) && isFirstPayloadReceived(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + // no need to send anything, since we have not started a stream yet (no logical wire) + return false; + } + + ReassemblyUtils.synchronizedRelease(this, previousState); + + final boolean firstFrameSent = isFirstFrameSent(previousState); + if (firstFrameSent) { + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final ByteBuf cancelFrame = CancelFrameCodec.encode(this.allocator, streamId); + this.connection.sendFrame(streamId, cancelFrame); + } + + return firstFrameSent; + } + + @Override + public void onError(Throwable t) { + if (this.outboundDone) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + this.outboundError = t; + this.outboundDone = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + if (this.isFirstSignal) { + this.inboundDone = true; + this.inboundSubscriber.onError(t); + return; + } else if (!isReadyToSendFirstFrame(previousState)) { + // first signal is received but we are still waiting for lease permit to be issued, + // thus, just propagates error to actual subscriber + + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + + firstPayload.release(); + + this.inboundDone = true; + this.inboundSubscriber.onError(t); + + return; + } + + ReassemblyUtils.synchronizedRelease(this, previousState); + + if (isFirstFrameSent(previousState)) { + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + // propagates error to remote responder + final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); + this.connection.sendFrame(streamId, errorFrame); + + if (!isInboundTerminated(previousState)) { + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + synchronized (this) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } + + this.inboundDone = true; + this.inboundSubscriber.onError(t); + } + } else { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + } + } + } + + @Override + public void onComplete() { + if (this.outboundDone) { + return; + } + + this.outboundDone = true; + + long previousState = markOutboundTerminated(STATE, this, true); + if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + return; + } + + if (!isFirstFrameSent(previousState)) { + if (!isFirstPayloadReceived(previousState)) { + // first signal, thus, just propagates error to actual subscriber + this.inboundSubscriber.onError(new CancellationException("Empty Source")); + } + return; + } + + final int streamId = this.streamId; + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); + + this.connection.sendFrame(streamId, completeFrame); + + if (isInboundTerminated(previousState)) { + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, null); + } + } + } + + @Override + public final void handleComplete() { + if (this.inboundDone) { + return; + } + + this.inboundDone = true; + + long previousState = markInboundTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + if (isOutboundTerminated(previousState)) { + this.requesterResponderSupport.remove(this.streamId, this); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, null); + } + } + + this.inboundSubscriber.onComplete(); + } + + @Override + public final void handlePermitError(Throwable cause) { + this.inboundDone = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState) || isInboundTerminated(previousState)) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + final Payload p = this.firstPayload; + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onReject(cause, FrameType.REQUEST_CHANNEL, p.metadata()); + } + p.release(); + + this.inboundSubscriber.onError(cause); + } + + @Override + public final void handleError(Throwable cause) { + if (this.inboundDone) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + this.inboundDone = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState) || isInboundTerminated(previousState)) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + ReassemblyUtils.release(this, previousState); + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, cause); + } + + this.inboundSubscriber.onError(cause); + } + + @Override + public final void handlePayload(Payload value) { + synchronized (this) { + if (this.inboundDone) { + value.release(); + return; + } + + final long produced = this.produced; + if (this.requested == produced) { + value.release(); + if (!tryCancel()) { + return; + } + + final Throwable cause = + Exceptions.failWithOverflow( + "The number of messages received exceeds the number requested"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, cause); + } + + this.inboundSubscriber.onError(cause); + return; + } + + this.produced = produced + 1; + + this.inboundSubscriber.onNext(value); + } + } + + @Override + public void handleRequestN(long n) { + this.outboundSubscription.request(n); + } + + @Override + public void handleCancel() { + if (this.outboundDone) { + return; + } + + long previousState = markOutboundTerminated(STATE, this, false); + if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + return; + } + + final boolean inboundTerminated = isInboundTerminated(previousState); + if (inboundTerminated) { + this.requesterResponderSupport.remove(this.streamId, this); + } + + this.outboundSubscription.cancel(); + + if (inboundTerminated) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, null); + } + } + } + + @Override + public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) { + handleNextSupport( + STATE, + this, + this, + this.inboundSubscriber, + this.payloadDecoder, + this.allocator, + this.maxInboundPayloadSize, + frame, + hasFollows, + isLastPayload); + } + + @Override + @NonNull + public Context currentContext() { + long state = this.state; + + if (isSubscribedOrTerminated(state)) { + Context cachedContext = this.cachedContext; + if (cachedContext == null) { + cachedContext = + this.inboundSubscriber.currentContext().putAll((ContextView) DISCARD_CONTEXT); + this.cachedContext = cachedContext; + } + return cachedContext; + } + + return Context.empty(); + } + + @Override + public CompositeByteBuf getFrames() { + return this.frames; + } + + @Override + public void setFrames(CompositeByteBuf byteBuf) { + this.frames = byteBuf; + } + + @Override + @Nullable + public Object scanUnsafe(Attr key) { + // touch guard + long state = this.state; + + if (key == Attr.TERMINATED) return isTerminated(state); + if (key == Attr.REQUESTED_FROM_DOWNSTREAM) return state; + + return null; + } + + @Override + @NonNull + public String stepName() { + return "source(RequestChannelFlux)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java new file mode 100644 index 000000000..32128fee4 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java @@ -0,0 +1,922 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.*; +import static reactor.core.Exceptions.TERMINATED; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.exceptions.CanceledException; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +final class RequestChannelResponderSubscriber extends Flux + implements ResponderFrameHandler, Subscription, CoreSubscriber { + + static final Logger logger = LoggerFactory.getLogger(RequestChannelResponderSubscriber.class); + + final int streamId; + final ByteBufAllocator allocator; + final PayloadDecoder payloadDecoder; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + final long firstRequest; + + @Nullable final RequestInterceptor requestInterceptor; + + final RSocket handler; + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(RequestChannelResponderSubscriber.class, "state"); + + Payload firstPayload; + + Subscription outboundSubscription; + CoreSubscriber inboundSubscriber; + + CompositeByteBuf frames; + + volatile Throwable inboundError; + static final AtomicReferenceFieldUpdater + INBOUND_ERROR = + AtomicReferenceFieldUpdater.newUpdater( + RequestChannelResponderSubscriber.class, Throwable.class, "inboundError"); + + boolean inboundDone; + boolean outboundDone; + long requested; + long produced; + + public RequestChannelResponderSubscriber( + int streamId, + long firstRequestN, + ByteBuf firstFrame, + RequesterResponderSupport requesterResponderSupport, + RSocket handler) { + this.streamId = streamId; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + this.handler = handler; + this.firstRequest = firstRequestN; + + this.frames = + ReassemblyUtils.addFollowingFrame( + allocator.compositeBuffer(), firstFrame, true, maxInboundPayloadSize); + STATE.lazySet(this, REASSEMBLING_FLAG); + } + + public RequestChannelResponderSubscriber( + int streamId, + long firstRequestN, + Payload firstPayload, + RequesterResponderSupport requesterResponderSupport) { + this.streamId = streamId; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + this.firstRequest = firstRequestN; + this.firstPayload = firstPayload; + + this.handler = null; + this.frames = null; + } + + @Override + // subscriber from the requestChannel method + public void subscribe(CoreSubscriber actual) { + + long previousState = markSubscribed(STATE, this); + if (isTerminated(previousState)) { + Throwable t = Exceptions.terminate(INBOUND_ERROR, this); + if (t != TERMINATED) { + //noinspection ConstantConditions + Operators.error(actual, t); + } else { + Operators.error( + actual, + new CancellationException("RequestChannelSubscriber has already been terminated")); + } + return; + } + + if (isSubscribed(previousState)) { + Operators.error( + actual, new IllegalStateException("RequestChannelSubscriber allows only one Subscriber")); + return; + } + + this.inboundSubscriber = actual; + // sends sender as a subscription since every request|cancel signal should be encoded to + // requestNFrame|cancelFrame + actual.onSubscribe(this); + } + + @Override + // subscription to the outbound + public void onSubscribe(Subscription outboundSubscription) { + if (Operators.validate(this.outboundSubscription, outboundSubscription)) { + this.outboundSubscription = outboundSubscription; + outboundSubscription.request(this.firstRequest); + } + } + + @Override + public void request(long n) { + if (!Operators.validate(n)) { + return; + } + + this.requested = Operators.addCap(this.requested, n); + + long previousState = StateUtils.addRequestN(STATE, this, n); + if (isTerminated(previousState)) { + // full termination can be the result of both sides completion / cancelFrame / remote or local + // error + // therefore, we need to check inbound error value, to see what should be done + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError == TERMINATED) { + // means inbound was already terminated + return; + } + + if (inboundError != null || this.inboundDone) { + final CoreSubscriber inboundSubscriber = this.inboundSubscriber; + + Payload firstPayload = this.firstPayload; + if (firstPayload != null) { + this.firstPayload = null; + + this.produced++; + + inboundSubscriber.onNext(firstPayload); + } + + if (inboundError != null) { + inboundSubscriber.onError(inboundError); + } else { + inboundSubscriber.onComplete(); + } + } + return; + } + + if (isInboundTerminated(previousState)) { + // inbound only can be terminated in case of cancellation or complete frame + if (!hasRequested(previousState) && !isFirstFrameSent(previousState) && this.inboundDone) { + final CoreSubscriber inboundSubscriber = this.inboundSubscriber; + + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + + this.produced++; + + inboundSubscriber.onNext(firstPayload); + inboundSubscriber.onComplete(); + + markFirstFrameSent(STATE, this); + } + return; + } + + if (hasRequested(previousState)) { + if (isFirstFrameSent(previousState) + && !isMaxAllowedRequestN(StateUtils.extractRequestN(previousState))) { + final int streamId = this.streamId; + final ByteBuf requestNFrame = RequestNFrameCodec.encode(this.allocator, streamId, n); + this.connection.sendFrame(streamId, requestNFrame); + } + return; + } + + final CoreSubscriber inboundSubscriber = this.inboundSubscriber; + + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + + this.produced++; + + inboundSubscriber.onNext(firstPayload); + + previousState = markFirstFrameSent(STATE, this); + if (isTerminated(previousState)) { + // full termination can be the result of both sides completion / cancelFrame / remote or local + // error + // therefore, we need to check inbound error value, to see what should be done + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError == TERMINATED) { + // means inbound was already terminated + return; + } + + if (inboundError != null) { + inboundSubscriber.onError(inboundError); + } else if (this.inboundDone) { + inboundSubscriber.onComplete(); + } + return; + } + + if (isInboundTerminated(previousState)) { + // inbound only can be terminated in case of cancellation or complete frame + if (this.inboundDone) { + inboundSubscriber.onComplete(); + } + return; + } + + long requestN = StateUtils.extractRequestN(previousState); + if (isMaxAllowedRequestN(requestN)) { + final int streamId = this.streamId; + final ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, streamId, requestN); + this.connection.sendFrame(streamId, requestNFrame); + } else { + long firstRequestN = requestN - 1; + if (firstRequestN > 0) { + final int streamId = this.streamId; + final ByteBuf requestNFrame = + RequestNFrameCodec.encode(this.allocator, streamId, firstRequestN); + this.connection.sendFrame(streamId, requestNFrame); + } + } + } + + @Override + // inbound cancellation + public void cancel() { + long previousState = markInboundTerminated(STATE, this); + if (isTerminated(previousState) || isInboundTerminated(previousState)) { + INBOUND_ERROR.lazySet(this, TERMINATED); + return; + } + + if (!isFirstFrameSent(previousState) && !hasRequested(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } + + final int streamId = this.streamId; + + final boolean isOutboundTerminated = isOutboundTerminated(previousState); + if (isOutboundTerminated) { + this.requesterResponderSupport.remove(streamId, this); + } + + final ByteBuf cancelFrame = CancelFrameCodec.encode(this.allocator, streamId); + this.connection.sendFrame(streamId, cancelFrame); + + if (isOutboundTerminated) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, null); + } + } + } + + @Override + public final void handleCancel() { + Subscription outboundSubscription = this.outboundSubscription; + if (outboundSubscription == null) { + // if subscription is null, it means that streams has not yet reassembled all the fragments + // and fragmentation of the first frame was cancelled before + lazyTerminate(STATE, this); + + this.requesterResponderSupport.remove(this.streamId, this); + + final CompositeByteBuf frames = this.frames; + if (frames != null) { + this.frames = null; + frames.release(); + } else { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onCancel(this.streamId, FrameType.REQUEST_CHANNEL); + } + return; + } + + long previousState = this.tryTerminate(true); + if (isTerminated(previousState)) { + return; + } + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onCancel(this.streamId, FrameType.REQUEST_CHANNEL); + } + } + + final long tryTerminate(boolean isFromInbound) { + Exceptions.addThrowable( + INBOUND_ERROR, this, new CancellationException("Inbound has been canceled")); + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return previousState; + } + + this.requesterResponderSupport.remove(this.streamId, this); + + if (isReassembling(previousState)) { + final CompositeByteBuf frames = this.frames; + this.frames = null; + if (isFromInbound) { + frames.release(); + } else { + synchronized (frames) { + frames.release(); + } + } + } + + final Subscription outboundSubscription = this.outboundSubscription; + if (outboundSubscription == null) { + return previousState; + } + + outboundSubscription.cancel(); + + if (!isSubscribed(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } else if (isFirstFrameSent(previousState) && !isInboundTerminated(previousState)) { + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError != TERMINATED) { + if (isFromInbound) { + this.inboundDone = true; + //noinspection ConstantConditions + this.inboundSubscriber.onError(inboundError); + } else { + synchronized (this) { + this.inboundDone = true; + //noinspection ConstantConditions + this.inboundSubscriber.onError(inboundError); + } + } + } + } + + return previousState; + } + + final void handlePayload(Payload p) { + synchronized (this) { + if (this.inboundDone) { + // payload from network so it has refCnt > 0 + p.release(); + return; + } + + final long produced = this.produced; + if (this.requested == produced) { + p.release(); + + this.inboundDone = true; + + final Throwable cause = + Exceptions.failWithOverflow( + "The number of messages received exceeds the number requested"); + boolean wasThrowableAdded = Exceptions.addThrowable(INBOUND_ERROR, this, cause); + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + if (!wasThrowableAdded) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + } + return; + } + + this.requesterResponderSupport.remove(this.streamId, this); + + this.connection.sendFrame( + streamId, + ErrorFrameCodec.encode( + this.allocator, streamId, new CanceledException(cause.getMessage()))); + + if (!isSubscribed(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } else if (isFirstFrameSent(previousState) && !isInboundTerminated(previousState)) { + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError != TERMINATED) { + //noinspection ConstantConditions + this.inboundSubscriber.onError(inboundError); + } + } + + // this is downstream subscription so need to cancel it just in case error signal has not + // reached it + // needs for disconnected upstream and downstream case + this.outboundSubscription.cancel(); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, cause); + } + return; + } + + this.produced = produced + 1; + + this.inboundSubscriber.onNext(p); + } + } + + @Override + public final void handleError(Throwable t) { + if (this.inboundDone) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + this.inboundDone = true; + boolean wasThrowableAdded = Exceptions.addThrowable(INBOUND_ERROR, this, t); + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + if (!wasThrowableAdded) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + } + return; + } + + this.requesterResponderSupport.remove(this.streamId, this); + + if (isReassembling(previousState)) { + final CompositeByteBuf frames = this.frames; + this.frames = null; + frames.release(); + } + + if (!isSubscribed(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } else if (isFirstFrameSent(previousState) && !isInboundTerminated(previousState)) { + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError != TERMINATED) { + //noinspection ConstantConditions + this.inboundSubscriber.onError(inboundError); + } + } + + // this is downstream subscription so need to cancel it just in case error signal has not + // reached it + // needs for disconnected upstream and downstream case + this.outboundSubscription.cancel(); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, t); + } + } + + @Override + public void handleComplete() { + if (this.inboundDone) { + return; + } + + this.inboundDone = true; + + long previousState = markInboundTerminated(STATE, this); + + final boolean isOutboundTerminated = isOutboundTerminated(previousState); + if (isOutboundTerminated) { + this.requesterResponderSupport.remove(this.streamId, this); + } + + if (isFirstFrameSent(previousState)) { + this.inboundSubscriber.onComplete(); + } + + if (isOutboundTerminated) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, null); + } + } + } + + @Override + public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) { + long state = this.state; + if (isTerminated(state)) { + return; + } + + if (!hasFollows && !isReassembling(state)) { + Payload payload; + try { + payload = this.payloadDecoder.apply(frame); + } catch (Throwable t) { + long previousState = this.tryTerminate(true); + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, t); + } + + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + this.outboundDone = true; + // send error to terminate interaction + final int streamId = this.streamId; + final ByteBuf errorFrame = + ErrorFrameCodec.encode(this.allocator, streamId, new CanceledException(t.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } + return; + } + + this.handlePayload(payload); + if (isLastPayload) { + this.handleComplete(); + } + return; + } + + CompositeByteBuf frames = this.frames; + if (frames == null) { + frames = + ReassemblyUtils.addFollowingFrame( + this.allocator.compositeBuffer(), frame, hasFollows, this.maxInboundPayloadSize); + this.frames = frames; + + long previousState = markReassembling(STATE, this); + if (isTerminated(previousState)) { + this.frames = null; + frames.release(); + return; + } + } else { + try { + frames = + ReassemblyUtils.addFollowingFrame( + frames, frame, hasFollows, this.maxInboundPayloadSize); + } catch (IllegalStateException e) { + if (isTerminated(this.state)) { + return; + } + + long previousState = this.tryTerminate(true); + if (isTerminated(previousState)) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, e); + } + + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + this.outboundDone = true; + // send error to terminate interaction + final int streamId = this.streamId; + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, + streamId, + new CanceledException("Failed to reassemble payload. Cause: " + e.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); + } + + return; + } + } + + if (!hasFollows) { + long previousState = markReassembled(STATE, this); + if (isTerminated(previousState)) { + return; + } + + this.frames = null; + + Payload payload; + try { + payload = this.payloadDecoder.apply(frames); + frames.release(); + } catch (Throwable t) { + ReferenceCountUtil.safeRelease(frames); + + previousState = this.tryTerminate(true); + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, t); + } + + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + // send error to terminate interaction + final int streamId = this.streamId; + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, + streamId, + new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } + + return; + } + + if (this.outboundSubscription == null) { + this.firstPayload = payload; + Flux source = this.handler.requestChannel(this); + source.subscribe(this); + } else { + this.handlePayload(payload); + } + + if (isLastPayload) { + this.handleComplete(); + } + } + } + + @Override + public void onNext(Payload p) { + if (this.outboundDone) { + ReferenceCountUtil.safeRelease(p); + return; + } + + final int streamId = this.streamId; + final DuplexConnection connection = this.connection; + final ByteBufAllocator allocator = this.allocator; + + final int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, p, false)) { + p.release(); + + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + long previousState = this.tryTerminate(false); + if (isTerminated(previousState)) { + Operators.onErrorDropped( + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)), + this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); + } + + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + final CanceledException e = + new CanceledException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, streamId, e); + connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); + } + return; + } + } catch (IllegalReferenceCountException e) { + + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + long previousState = this.tryTerminate(false); + if (isTerminated(previousState)) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); + } + + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException("Failed to validate payload. Cause:" + e.getMessage())); + connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); + } + return; + } + + try { + sendReleasingPayload(streamId, FrameType.NEXT, mtu, p, connection, allocator, false); + } catch (Throwable t) { + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + long previousState = this.tryTerminate(false); + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null && !isTerminated(previousState)) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } + } + } + + @Override + public void onError(Throwable t) { + if (this.outboundDone) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + boolean wasThrowableAdded = + Exceptions.addThrowable( + INBOUND_ERROR, + this, + new CancellationException("Outbound has terminated with an error")); + this.outboundDone = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + final int streamId = this.streamId; + + this.requesterResponderSupport.remove(streamId, this); + + if (isReassembling(previousState)) { + final CompositeByteBuf frames = this.frames; + this.frames = null; + synchronized (frames) { + frames.release(); + } + } + + if (!isSubscribed(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } else if (wasThrowableAdded + && isFirstFrameSent(previousState) + && !isInboundTerminated(previousState)) { + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError != TERMINATED) { + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + synchronized (this) { + this.inboundDone = true; + //noinspection ConstantConditions + this.inboundSubscriber.onError(inboundError); + } + } + } + + final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } + } + + @Override + public void onComplete() { + if (this.outboundDone) { + return; + } + + this.outboundDone = true; + + long previousState = markOutboundTerminated(STATE, this, false); + if (isTerminated(previousState)) { + return; + } + + final int streamId = this.streamId; + + final boolean isInboundTerminated = isInboundTerminated(previousState); + if (isInboundTerminated) { + this.requesterResponderSupport.remove(streamId, this); + } + + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); + this.connection.sendFrame(streamId, completeFrame); + + if (isInboundTerminated) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, null); + } + } + } + + @Override + public final void handleRequestN(long n) { + this.outboundSubscription.request(n); + } + + @Override + public Context currentContext() { + return SendUtils.DISCARD_CONTEXT; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java new file mode 100644 index 000000000..a13b105b5 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java @@ -0,0 +1,400 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.ReassemblyUtils.handleNextSupport; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; + +final class RequestResponseRequesterMono extends Mono + implements RequesterFrameHandler, LeasePermitHandler, Subscription, Scannable { + + final ByteBufAllocator allocator; + final Payload payload; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + final PayloadDecoder payloadDecoder; + + @Nullable final RequesterLeaseTracker requesterLeaseTracker; + @Nullable final RequestInterceptor requestInterceptor; + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(RequestResponseRequesterMono.class, "state"); + + int streamId; + CoreSubscriber actual; + CompositeByteBuf frames; + boolean done; + + RequestResponseRequesterMono( + Payload payload, RequesterResponderSupport requesterResponderSupport) { + + this.allocator = requesterResponderSupport.getAllocator(); + this.payload = payload; + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requesterLeaseTracker = requesterResponderSupport.getRequesterLeaseTracker(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + final IllegalStateException e = + new IllegalStateException("RequestResponseMono allows only a single " + "Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_RESPONSE, null); + } + + Operators.error(actual, e); + return; + } + + final Payload p = this.payload; + try { + if (!isValid(this.mtu, this.maxFrameLength, p, false)) { + lazyTerminate(STATE, this); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_RESPONSE, p.metadata()); + } + + p.release(); + + Operators.error(actual, e); + return; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_RESPONSE, null); + } + + Operators.error(actual, e); + return; + } + + this.actual = actual; + actual.onSubscribe(this); + } + + @Override + public final void request(long n) { + if (!Operators.validate(n)) { + return; + } + + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + final boolean leaseEnabled = requesterLeaseTracker != null; + final long previousState = addRequestN(STATE, this, n, !leaseEnabled); + + if (isTerminated(previousState) || hasRequested(previousState)) { + return; + } + + if (leaseEnabled) { + requesterLeaseTracker.issue(this); + return; + } + + sendFirstPayload(this.payload); + } + + @Override + public boolean handlePermit() { + final long previousState = markReadyToSendFirstFrame(STATE, this); + + if (isTerminated(previousState)) { + return false; + } + + sendFirstPayload(this.payload); + return true; + } + + void sendFirstPayload(Payload payload) { + + final RequesterResponderSupport sm = this.requesterResponderSupport; + final DuplexConnection connection = this.connection; + final ByteBufAllocator allocator = this.allocator; + + final int streamId; + try { + streamId = sm.addAndGetNextStreamId(this); + this.streamId = streamId; + } catch (Throwable t) { + this.done = true; + final long previousState = markTerminated(STATE, this); + + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_RESPONSE, payload.metadata()); + } + + payload.release(); + + if (!isTerminated(previousState)) { + this.actual.onError(ut); + } + return; + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, FrameType.REQUEST_RESPONSE, payload.metadata()); + } + + try { + sendReleasingPayload( + streamId, FrameType.REQUEST_RESPONSE, this.mtu, payload, connection, allocator, true); + } catch (Throwable e) { + this.done = true; + lazyTerminate(STATE, this); + + sm.remove(streamId, this); + + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, e); + } + + this.actual.onError(e); + return; + } + + long previousState = markFirstFrameSent(STATE, this); + if (isTerminated(previousState)) { + if (this.done) { + return; + } + + sm.remove(streamId, this); + + final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); + connection.sendFrame(streamId, cancelFrame); + + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_RESPONSE); + } + } + } + + @Override + public final void cancel() { + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + if (isFirstFrameSent(previousState)) { + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + ReassemblyUtils.synchronizedRelease(this, previousState); + + this.connection.sendFrame(streamId, CancelFrameCodec.encode(this.allocator, streamId)); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_RESPONSE); + } + } else if (!isReadyToSendFirstFrame(previousState)) { + this.payload.release(); + } + } + + @Override + public final void handlePayload(Payload value) { + if (this.done) { + value.release(); + return; + } + + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + value.release(); + return; + } + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, null); + } + + final CoreSubscriber a = this.actual; + a.onNext(value); + a.onComplete(); + } + + @Override + public final void handleComplete() { + if (this.done) { + return; + } + + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, null); + } + + this.actual.onComplete(); + } + + @Override + public final void handlePermitError(Throwable cause) { + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(cause, this.actual.currentContext()); + return; + } + + final Payload p = this.payload; + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(cause, FrameType.REQUEST_RESPONSE, p.metadata()); + } + p.release(); + + this.actual.onError(cause); + } + + @Override + public final void handleError(Throwable cause) { + if (this.done) { + Operators.onErrorDropped(cause, this.actual.currentContext()); + return; + } + + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(cause, this.actual.currentContext()); + return; + } + + ReassemblyUtils.synchronizedRelease(this, previousState); + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, cause); + } + + this.actual.onError(cause); + } + + @Override + public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) { + handleNextSupport( + STATE, + this, + this, + this.actual, + this.payloadDecoder, + this.allocator, + this.maxInboundPayloadSize, + frame, + hasFollows, + isLastPayload); + } + + @Override + public CompositeByteBuf getFrames() { + return this.frames; + } + + @Override + public void setFrames(CompositeByteBuf byteBuf) { + this.frames = byteBuf; + } + + @Override + @Nullable + public Object scanUnsafe(Attr key) { + // touch guard + long state = this.state; + + if (key == Attr.TERMINATED) return isTerminated(state); + if (key == Attr.PREFETCH) return 0; + + return null; + } + + @Override + @NonNull + public String stepName() { + return "source(RequestResponseMono)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java new file mode 100644 index 000000000..3d9d020ff --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java @@ -0,0 +1,358 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.SendUtils.sendReleasingPayload; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.exceptions.CanceledException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +final class RequestResponseResponderSubscriber + implements ResponderFrameHandler, CoreSubscriber { + + static final Logger logger = LoggerFactory.getLogger(RequestResponseResponderSubscriber.class); + + final int streamId; + final ByteBufAllocator allocator; + final PayloadDecoder payloadDecoder; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + final RSocket handler; + + @Nullable final RequestInterceptor requestInterceptor; + + boolean done; + CompositeByteBuf frames; + + volatile Subscription s; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater( + RequestResponseResponderSubscriber.class, Subscription.class, "s"); + + public RequestResponseResponderSubscriber( + int streamId, + ByteBuf firstFrame, + RequesterResponderSupport requesterResponderSupport, + RSocket handler) { + this.streamId = streamId; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + this.handler = handler; + + this.frames = + ReassemblyUtils.addFollowingFrame( + allocator.compositeBuffer(), firstFrame, true, maxInboundPayloadSize); + } + + public RequestResponseResponderSubscriber( + int streamId, RequesterResponderSupport requesterResponderSupport) { + this.streamId = streamId; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + + this.payloadDecoder = null; + this.handler = null; + this.frames = null; + } + + @Override + public void onSubscribe(Subscription subscription) { + if (Operators.validate(this.s, subscription)) { + S.lazySet(this, subscription); + subscription.request(Long.MAX_VALUE); + } + } + + @Override + public void onNext(@Nullable Payload p) { + if (this.done) { + if (p != null) { + p.release(); + } + return; + } + + final Subscription currentSubscription = this.s; + if (currentSubscription == Operators.cancelledSubscription() + || !S.compareAndSet(this, currentSubscription, Operators.cancelledSubscription())) { + if (p != null) { + p.release(); + } + return; + } + + this.done = true; + + final int streamId = this.streamId; + final DuplexConnection connection = this.connection; + final ByteBufAllocator allocator = this.allocator; + + this.requesterResponderSupport.remove(streamId, this); + + if (p == null) { + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(allocator, streamId); + connection.sendFrame(streamId, completeFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, null); + } + return; + } + + final int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, p, false)) { + currentSubscription.cancel(); + + p.release(); + + final CanceledException e = + new CanceledException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, streamId, e); + connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, e); + } + return; + } + } catch (IllegalReferenceCountException e) { + currentSubscription.cancel(); + + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException("Failed to validate payload. Cause" + e.getMessage())); + connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, e); + } + return; + } + + try { + sendReleasingPayload(streamId, FrameType.NEXT_COMPLETE, mtu, p, connection, allocator, false); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, null); + } + } catch (Throwable t) { + currentSubscription.cancel(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, t); + } + } + } + + @Override + public void onError(Throwable t) { + if (this.done) { + logger.debug("Dropped error", t); + return; + } + + final Subscription currentSubscription = this.s; + if (currentSubscription == Operators.cancelledSubscription() + || !S.compareAndSet(this, currentSubscription, Operators.cancelledSubscription())) { + logger.debug("Dropped error", t); + return; + } + + this.done = true; + + final int streamId = this.streamId; + + this.requesterResponderSupport.remove(streamId, this); + + final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, t); + } + } + + @Override + public void onComplete() { + onNext(null); + } + + @Override + public void handleCancel() { + final Subscription currentSubscription = this.s; + if (currentSubscription == Operators.cancelledSubscription()) { + return; + } + + if (currentSubscription == null) { + // if subscription is null, it means that streams has not yet reassembled all the fragments + // and fragmentation of the first frame was cancelled before + S.lazySet(this, Operators.cancelledSubscription()); + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final CompositeByteBuf frames = this.frames; + if (frames != null) { + this.frames = null; + frames.release(); + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_RESPONSE); + } + return; + } + + if (!S.compareAndSet(this, currentSubscription, Operators.cancelledSubscription())) { + return; + } + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + currentSubscription.cancel(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_RESPONSE); + } + } + + @Override + public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) { + final CompositeByteBuf frames = this.frames; + if (frames == null) { + return; + } + + try { + ReassemblyUtils.addFollowingFrame(frames, frame, hasFollows, this.maxInboundPayloadSize); + } catch (IllegalStateException t) { + S.lazySet(this, Operators.cancelledSubscription()); + + this.requesterResponderSupport.remove(this.streamId, this); + + this.frames = null; + frames.release(); + + logger.debug("Reassembly has failed", t); + + // sends error frame from the responder side to tell that something went wrong + final int streamId = this.streamId; + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, + streamId, + new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, t); + } + return; + } + + if (!hasFollows) { + this.frames = null; + Payload payload; + try { + payload = this.payloadDecoder.apply(frames); + frames.release(); + } catch (Throwable t) { + S.lazySet(this, Operators.cancelledSubscription()); + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + ReferenceCountUtil.safeRelease(frames); + + logger.debug("Reassembly has failed", t); + + // sends error frame from the responder side to tell that something went wrong + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, + streamId, + new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, t); + } + return; + } + + final Mono source = this.handler.requestResponse(payload); + source.subscribe(this); + } + } + + @Override + public Context currentContext() { + return SendUtils.DISCARD_CONTEXT; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java new file mode 100644 index 000000000..6182ca506 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java @@ -0,0 +1,449 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.ReassemblyUtils.handleNextSupport; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; + +final class RequestStreamRequesterFlux extends Flux + implements RequesterFrameHandler, LeasePermitHandler, Subscription, Scannable { + + final ByteBufAllocator allocator; + final Payload payload; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + final PayloadDecoder payloadDecoder; + + @Nullable final RequesterLeaseTracker requesterLeaseTracker; + @Nullable final RequestInterceptor requestInterceptor; + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(RequestStreamRequesterFlux.class, "state"); + + int streamId; + CoreSubscriber inboundSubscriber; + CompositeByteBuf frames; + boolean done; + long requested; + long produced; + + RequestStreamRequesterFlux(Payload payload, RequesterResponderSupport requesterResponderSupport) { + this.allocator = requesterResponderSupport.getAllocator(); + this.payload = payload; + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requesterLeaseTracker = requesterResponderSupport.getRequesterLeaseTracker(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + final IllegalStateException e = + new IllegalStateException("RequestStreamFlux allows only a single Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_STREAM, null); + } + + Operators.error(actual, e); + return; + } + + final Payload p = this.payload; + try { + if (!isValid(this.mtu, this.maxFrameLength, p, false)) { + lazyTerminate(STATE, this); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_STREAM, p.metadata()); + } + + p.release(); + + Operators.error(actual, e); + return; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_STREAM, null); + } + + Operators.error(actual, e); + return; + } + + this.inboundSubscriber = actual; + actual.onSubscribe(this); + } + + @Override + public final void request(long n) { + if (!Operators.validate(n)) { + return; + } + + this.requested = Operators.addCap(this.requested, n); + + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + final boolean leaseEnabled = requesterLeaseTracker != null; + final long previousState = addRequestN(STATE, this, n, !leaseEnabled); + if (isTerminated(previousState)) { + return; + } + + if (hasRequested(previousState)) { + if (isFirstFrameSent(previousState) + && !isMaxAllowedRequestN(extractRequestN(previousState))) { + final int streamId = this.streamId; + final ByteBuf requestNFrame = RequestNFrameCodec.encode(this.allocator, streamId, n); + this.connection.sendFrame(streamId, requestNFrame); + } + return; + } + + if (leaseEnabled) { + requesterLeaseTracker.issue(this); + return; + } + + sendFirstPayload(this.payload, n); + } + + @Override + public boolean handlePermit() { + final long previousState = markReadyToSendFirstFrame(STATE, this); + + if (isTerminated(previousState)) { + return false; + } + + sendFirstPayload(this.payload, extractRequestN(previousState)); + return true; + } + + void sendFirstPayload(Payload payload, long initialRequestN) { + + final RequesterResponderSupport sm = this.requesterResponderSupport; + final DuplexConnection connection = this.connection; + final ByteBufAllocator allocator = this.allocator; + + final int streamId; + try { + streamId = sm.addAndGetNextStreamId(this); + this.streamId = streamId; + } catch (Throwable t) { + this.done = true; + final long previousState = markTerminated(STATE, this); + + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_STREAM, payload.metadata()); + } + + payload.release(); + + if (!isTerminated(previousState)) { + this.inboundSubscriber.onError(ut); + } + return; + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, FrameType.REQUEST_STREAM, payload.metadata()); + } + + try { + sendReleasingPayload( + streamId, + FrameType.REQUEST_STREAM, + initialRequestN, + this.mtu, + payload, + connection, + allocator, + false); + } catch (Throwable t) { + this.done = true; + lazyTerminate(STATE, this); + + sm.remove(streamId, this); + + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, t); + } + + this.inboundSubscriber.onError(t); + return; + } + + long previousState = markFirstFrameSent(STATE, this); + if (isTerminated(previousState)) { + if (this.done) { + return; + } + + final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); + connection.sendFrame(streamId, cancelFrame); + + sm.remove(streamId, this); + + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_STREAM); + } + return; + } + + if (isMaxAllowedRequestN(initialRequestN)) { + return; + } + + long requestN = extractRequestN(previousState); + if (isMaxAllowedRequestN(requestN)) { + final ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, streamId, requestN); + connection.sendFrame(streamId, requestNFrame); + return; + } + + if (requestN > initialRequestN) { + final ByteBuf requestNFrame = + RequestNFrameCodec.encode(allocator, streamId, requestN - initialRequestN); + connection.sendFrame(streamId, requestNFrame); + } + } + + @Override + public final void cancel() { + final long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + if (isFirstFrameSent(previousState)) { + final int streamId = this.streamId; + + ReassemblyUtils.synchronizedRelease(this, previousState); + + this.connection.sendFrame(streamId, CancelFrameCodec.encode(this.allocator, streamId)); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_STREAM); + } + } else if (!isReadyToSendFirstFrame(previousState)) { + // no need to send anything, since the first request has not happened + this.payload.release(); + } + } + + @Override + public final void handlePayload(Payload p) { + if (this.done) { + p.release(); + return; + } + + final long produced = this.produced; + if (this.requested == produced) { + p.release(); + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + final int streamId = this.streamId; + + final IllegalStateException cause = + Exceptions.failWithOverflow( + "The number of messages received exceeds the number requested"); + this.connection.sendFrame(streamId, CancelFrameCodec.encode(this.allocator, streamId)); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, cause); + } + + this.inboundSubscriber.onError(cause); + return; + } + + this.produced = produced + 1; + + this.inboundSubscriber.onNext(p); + } + + @Override + public final void handleComplete() { + if (this.done) { + return; + } + + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, null); + } + + this.inboundSubscriber.onComplete(); + } + + @Override + public final void handlePermitError(Throwable cause) { + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + final Payload p = this.payload; + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(cause, FrameType.REQUEST_STREAM, p.metadata()); + } + p.release(); + + this.inboundSubscriber.onError(cause); + } + + @Override + public final void handleError(Throwable cause) { + if (this.done) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + ReassemblyUtils.synchronizedRelease(this, previousState); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, cause); + } + + this.inboundSubscriber.onError(cause); + } + + @Override + public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) { + handleNextSupport( + STATE, + this, + this, + this.inboundSubscriber, + this.payloadDecoder, + this.allocator, + this.maxInboundPayloadSize, + frame, + hasFollows, + isLastPayload); + } + + @Override + public CompositeByteBuf getFrames() { + return this.frames; + } + + @Override + public void setFrames(CompositeByteBuf byteBuf) { + this.frames = byteBuf; + } + + @Override + @Nullable + public Object scanUnsafe(Attr key) { + // touch guard + long state = this.state; + + if (key == Attr.TERMINATED) return isTerminated(state); + if (key == Attr.REQUESTED_FROM_DOWNSTREAM) return extractRequestN(state); + + return null; + } + + @Override + @NonNull + public String stepName() { + return "source(RequestStreamFlux)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java new file mode 100644 index 000000000..48903ae38 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java @@ -0,0 +1,395 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.SendUtils.sendReleasingPayload; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.exceptions.CanceledException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +final class RequestStreamResponderSubscriber + implements ResponderFrameHandler, CoreSubscriber { + + static final Logger logger = LoggerFactory.getLogger(RequestStreamResponderSubscriber.class); + + final int streamId; + final long firstRequest; + final ByteBufAllocator allocator; + final PayloadDecoder payloadDecoder; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + + @Nullable final RequestInterceptor requestInterceptor; + + final RSocket handler; + + volatile Subscription s; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater( + RequestStreamResponderSubscriber.class, Subscription.class, "s"); + + CompositeByteBuf frames; + boolean done; + + public RequestStreamResponderSubscriber( + int streamId, + long firstRequest, + ByteBuf firstFrame, + RequesterResponderSupport requesterResponderSupport, + RSocket handler) { + this.streamId = streamId; + this.firstRequest = firstRequest; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + this.handler = handler; + this.frames = + ReassemblyUtils.addFollowingFrame( + allocator.compositeBuffer(), firstFrame, true, maxInboundPayloadSize); + } + + public RequestStreamResponderSubscriber( + int streamId, long firstRequest, RequesterResponderSupport requesterResponderSupport) { + this.streamId = streamId; + this.firstRequest = firstRequest; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + + this.payloadDecoder = null; + this.handler = null; + this.frames = null; + } + + @Override + public void onSubscribe(Subscription subscription) { + if (Operators.validate(this.s, subscription)) { + final long firstRequest = this.firstRequest; + S.lazySet(this, subscription); + subscription.request(firstRequest); + } + } + + @Override + public void onNext(Payload p) { + if (this.done) { + ReferenceCountUtil.safeRelease(p); + return; + } + + final int streamId = this.streamId; + final DuplexConnection sender = this.connection; + final ByteBufAllocator allocator = this.allocator; + + final int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, p, false)) { + p.release(); + + if (!this.tryTerminateOnError()) { + return; + } + + final CanceledException e = + new CanceledException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, streamId, e); + sender.sendFrame(streamId, errorFrame); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, e); + } + return; + } + } catch (IllegalReferenceCountException e) { + if (!this.tryTerminateOnError()) { + return; + } + + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException("Failed to validate payload. Cause" + e.getMessage())); + sender.sendFrame(streamId, errorFrame); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, e); + } + return; + } + + try { + sendReleasingPayload(streamId, FrameType.NEXT, mtu, p, sender, allocator, false); + } catch (Throwable t) { + if (!this.tryTerminateOnError()) { + return; + } + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, t); + } + } + } + + boolean tryTerminateOnError() { + final Subscription currentSubscription = this.s; + if (currentSubscription == Operators.cancelledSubscription()) { + return false; + } + + this.done = true; + + if (!S.compareAndSet(this, currentSubscription, Operators.cancelledSubscription())) { + return false; + } + + currentSubscription.cancel(); + + return true; + } + + @Override + public void onError(Throwable t) { + if (this.done) { + logger.debug("Dropped error", t); + return; + } + + this.done = true; + + if (S.getAndSet(this, Operators.cancelledSubscription()) == Operators.cancelledSubscription()) { + logger.debug("Dropped error", t); + return; + } + + final CompositeByteBuf frames = this.frames; + if (frames != null && frames.refCnt() > 0) { + frames.release(); + } + + final int streamId = this.streamId; + + final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); + this.connection.sendFrame(streamId, errorFrame); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, t); + } + } + + @Override + public void onComplete() { + if (this.done) { + return; + } + + this.done = true; + + if (S.getAndSet(this, Operators.cancelledSubscription()) == Operators.cancelledSubscription()) { + return; + } + + final int streamId = this.streamId; + + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); + this.connection.sendFrame(streamId, completeFrame); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, null); + } + } + + @Override + public void handleRequestN(long n) { + this.s.request(n); + } + + @Override + public final void handleCancel() { + final Subscription currentSubscription = this.s; + if (currentSubscription == Operators.cancelledSubscription()) { + return; + } + + if (currentSubscription == null) { + // if subscription is null, it means that streams has not yet reassembled all the fragments + // and fragmentation of the first frame was cancelled before + S.lazySet(this, Operators.cancelledSubscription()); + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final CompositeByteBuf frames = this.frames; + if (frames != null) { + this.frames = null; + frames.release(); + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_STREAM); + } + return; + } + + if (!S.compareAndSet(this, currentSubscription, Operators.cancelledSubscription())) { + return; + } + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + currentSubscription.cancel(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_STREAM); + } + } + + @Override + public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLastPayload) { + final CompositeByteBuf frames = this.frames; + if (frames == null) { + return; + } + + try { + ReassemblyUtils.addFollowingFrame( + frames, followingFrame, hasFollows, this.maxInboundPayloadSize); + } catch (IllegalStateException e) { + // if subscription is null, it means that streams has not yet reassembled all the fragments + // and fragmentation of the first frame was cancelled before + S.lazySet(this, Operators.cancelledSubscription()); + + final int streamId = this.streamId; + + this.frames = null; + frames.release(); + + // sends error frame from the responder side to tell that something went wrong + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, + streamId, + new CanceledException("Failed to reassemble payload. Cause: " + e.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, e); + } + + logger.debug("Reassembly has failed", e); + return; + } + + if (!hasFollows) { + this.frames = null; + Payload payload; + try { + payload = this.payloadDecoder.apply(frames); + frames.release(); + } catch (Throwable t) { + S.lazySet(this, Operators.cancelledSubscription()); + this.done = true; + + final int streamId = this.streamId; + + ReferenceCountUtil.safeRelease(frames); + + // sends error frame from the responder side to tell that something went wrong + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, + streamId, + new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, t); + } + + logger.debug("Reassembly has failed", t); + return; + } + + Flux source = this.handler.requestStream(payload); + source.subscribe(this); + } + } + + @Override + public Context currentContext() { + return SendUtils.DISCARD_CONTEXT; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequesterFrameHandler.java b/rsocket-core/src/main/java/io/rsocket/core/RequesterFrameHandler.java new file mode 100644 index 000000000..1f7b09af8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequesterFrameHandler.java @@ -0,0 +1,43 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.Payload; +import java.util.concurrent.CancellationException; +import reactor.util.annotation.Nullable; + +interface RequesterFrameHandler extends FrameHandler { + + void handlePayload(Payload payload); + + @Override + default void handleCancel() { + handleError( + new CancellationException( + "Cancellation was received but should not be possible for current request type")); + } + + @Override + default void handleRequestN(long n) { + // no ops + } + + @Nullable + CompositeByteBuf getFrames(); + + void setFrames(@Nullable CompositeByteBuf reassembledFrames); +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequesterLeaseTracker.java b/rsocket-core/src/main/java/io/rsocket/core/RequesterLeaseTracker.java new file mode 100644 index 000000000..50da83b8f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequesterLeaseTracker.java @@ -0,0 +1,135 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Availability; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.lease.Lease; +import io.rsocket.lease.MissingLeaseException; +import java.time.Duration; +import java.util.ArrayDeque; +import java.util.Queue; + +final class RequesterLeaseTracker implements Availability { + + final String tag; + final int maximumAllowedAwaitingPermitHandlersNumber; + final Queue awaitingPermitHandlersQueue; + + Lease currentLease = null; + int availableRequests; + + boolean isDisposed; + Throwable t; + + RequesterLeaseTracker(String tag, int maximumAllowedAwaitingPermitHandlersNumber) { + this.tag = tag; + this.maximumAllowedAwaitingPermitHandlersNumber = maximumAllowedAwaitingPermitHandlersNumber; + this.awaitingPermitHandlersQueue = new ArrayDeque<>(); + } + + synchronized void issue(LeasePermitHandler leasePermitHandler) { + if (this.isDisposed) { + leasePermitHandler.handlePermitError(this.t); + return; + } + + final int availableRequests = this.availableRequests; + final Lease l = this.currentLease; + final boolean leaseReceived = l != null; + final boolean isExpired = leaseReceived && isExpired(l); + + if (leaseReceived && availableRequests > 0 && !isExpired) { + if (leasePermitHandler.handlePermit()) { + this.availableRequests = availableRequests - 1; + } + } else { + final Queue queue = this.awaitingPermitHandlersQueue; + if (this.maximumAllowedAwaitingPermitHandlersNumber > queue.size()) { + queue.offer(leasePermitHandler); + } else { + final String tag = this.tag; + final String message; + if (!leaseReceived) { + message = String.format("[%s] Lease was not received yet", tag); + } else if (isExpired) { + message = String.format("[%s] Missing leases. Lease is expired", tag); + } else { + message = + String.format( + "[%s] Missing leases. Issued [%s] request allowance is used", + tag, availableRequests); + } + + final Throwable t = new MissingLeaseException(message); + leasePermitHandler.handlePermitError(t); + } + } + } + + void handleLeaseFrame(ByteBuf leaseFrame) { + final int numberOfRequests = LeaseFrameCodec.numRequests(leaseFrame); + final int timeToLiveMillis = LeaseFrameCodec.ttl(leaseFrame); + final ByteBuf metadata = LeaseFrameCodec.metadata(leaseFrame); + + synchronized (this) { + final Lease lease = + Lease.create(Duration.ofMillis(timeToLiveMillis), numberOfRequests, metadata); + final Queue queue = this.awaitingPermitHandlersQueue; + + int availableRequests = lease.numberOfRequests(); + + this.currentLease = lease; + if (queue.size() > 0) { + do { + final LeasePermitHandler handler = queue.poll(); + if (handler.handlePermit()) { + availableRequests--; + } + } while (availableRequests > 0 && queue.size() > 0); + } + + this.availableRequests = availableRequests; + } + } + + public synchronized void dispose(Throwable t) { + this.isDisposed = true; + this.t = t; + + final Queue queue = this.awaitingPermitHandlersQueue; + final int size = queue.size(); + + for (int i = 0; i < size; i++) { + final LeasePermitHandler leasePermitHandler = queue.poll(); + + //noinspection ConstantConditions + leasePermitHandler.handlePermitError(t); + } + } + + @Override + public synchronized double availability() { + final Lease lease = this.currentLease; + return lease != null ? this.availableRequests / (double) lease.numberOfRequests() : 0.0d; + } + + static boolean isExpired(Lease currentLease) { + return System.currentTimeMillis() >= currentLease.expirationTime(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java b/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java new file mode 100644 index 000000000..bea7dc1aa --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java @@ -0,0 +1,161 @@ +package io.rsocket.core; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.collection.IntObjectHashMap; +import io.netty.util.collection.IntObjectMap; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocket; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.util.Objects; +import java.util.function.Function; +import reactor.util.annotation.Nullable; + +class RequesterResponderSupport { + + private final int mtu; + private final int maxFrameLength; + private final int maxInboundPayloadSize; + private final PayloadDecoder payloadDecoder; + private final ByteBufAllocator allocator; + private final DuplexConnection connection; + @Nullable private final RequestInterceptor requestInterceptor; + + @Nullable final StreamIdSupplier streamIdSupplier; + final IntObjectMap activeStreams; + + public RequesterResponderSupport( + int mtu, + int maxFrameLength, + int maxInboundPayloadSize, + PayloadDecoder payloadDecoder, + DuplexConnection connection, + @Nullable StreamIdSupplier streamIdSupplier, + Function requestInterceptorFunction) { + + this.activeStreams = new IntObjectHashMap<>(); + this.mtu = mtu; + this.maxFrameLength = maxFrameLength; + this.maxInboundPayloadSize = maxInboundPayloadSize; + this.payloadDecoder = payloadDecoder; + this.allocator = connection.alloc(); + this.streamIdSupplier = streamIdSupplier; + this.connection = connection; + this.requestInterceptor = requestInterceptorFunction.apply((RSocket) this); + } + + public int getMtu() { + return mtu; + } + + public int getMaxFrameLength() { + return maxFrameLength; + } + + public int getMaxInboundPayloadSize() { + return maxInboundPayloadSize; + } + + public PayloadDecoder getPayloadDecoder() { + return payloadDecoder; + } + + public ByteBufAllocator getAllocator() { + return allocator; + } + + public DuplexConnection getDuplexConnection() { + return connection; + } + + @Nullable + public RequesterLeaseTracker getRequesterLeaseTracker() { + return null; + } + + @Nullable + public RequestInterceptor getRequestInterceptor() { + return requestInterceptor; + } + + /** + * Issues next {@code streamId} + * + * @return issued {@code streamId} + * @throws RuntimeException if the {@link RequesterResponderSupport} is terminated for any reason + */ + public int getNextStreamId() { + final StreamIdSupplier streamIdSupplier = this.streamIdSupplier; + if (streamIdSupplier != null) { + synchronized (this) { + return streamIdSupplier.nextStreamId(this.activeStreams); + } + } else { + throw new UnsupportedOperationException("Responder can not issue id"); + } + } + + /** + * Adds frameHandler and returns issued {@code streamId} back + * + * @param frameHandler to store + * @return issued {@code streamId} + * @throws RuntimeException if the {@link RequesterResponderSupport} is terminated for any reason + */ + public int addAndGetNextStreamId(FrameHandler frameHandler) { + final StreamIdSupplier streamIdSupplier = this.streamIdSupplier; + if (streamIdSupplier != null) { + final IntObjectMap activeStreams = this.activeStreams; + synchronized (this) { + final int streamId = streamIdSupplier.nextStreamId(activeStreams); + + activeStreams.put(streamId, frameHandler); + + return streamId; + } + } else { + throw new UnsupportedOperationException("Responder can not issue id"); + } + } + + public synchronized boolean add(int streamId, FrameHandler frameHandler) { + final IntObjectMap activeStreams = this.activeStreams; + // copy of Map.putIfAbsent(key, value) without `streamId` boxing + final FrameHandler previousHandler = activeStreams.get(streamId); + if (previousHandler == null) { + activeStreams.put(streamId, frameHandler); + return true; + } + return false; + } + + /** + * Resolves {@link FrameHandler} by {@code streamId} + * + * @param streamId used to resolve {@link FrameHandler} + * @return {@link FrameHandler} or {@code null} + */ + @Nullable + public synchronized FrameHandler get(int streamId) { + return this.activeStreams.get(streamId); + } + + /** + * Removes {@link FrameHandler} if it is present and equals to the given one + * + * @param streamId to lookup for {@link FrameHandler} + * @param frameHandler instance to check with the found one + * @return {@code true} if there is {@link FrameHandler} for the given {@code streamId} and the + * instance equals to the passed one + */ + public synchronized boolean remove(int streamId, FrameHandler frameHandler) { + final IntObjectMap activeStreams = this.activeStreams; + // copy of Map.remove(key, value) without `streamId` boxing + final FrameHandler curValue = activeStreams.get(streamId); + if (!Objects.equals(curValue, frameHandler)) { + return false; + } + activeStreams.remove(streamId); + return true; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ResolvingOperator.java b/rsocket-core/src/main/java/io/rsocket/core/ResolvingOperator.java new file mode 100644 index 000000000..50bef5b70 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ResolvingOperator.java @@ -0,0 +1,646 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import java.time.Duration; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.BiConsumer; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Disposable; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +// A copy of this class exists in io.rsocket.loadbalance + +class ResolvingOperator implements Disposable { + + static final CancellationException ON_DISPOSE = new CancellationException("Disposed"); + + volatile int wip; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(ResolvingOperator.class, "wip"); + + volatile BiConsumer[] subscribers; + + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater SUBSCRIBERS = + AtomicReferenceFieldUpdater.newUpdater( + ResolvingOperator.class, BiConsumer[].class, "subscribers"); + + @SuppressWarnings("unchecked") + static final BiConsumer[] EMPTY_UNSUBSCRIBED = new BiConsumer[0]; + + @SuppressWarnings("unchecked") + static final BiConsumer[] EMPTY_SUBSCRIBED = new BiConsumer[0]; + + @SuppressWarnings("unchecked") + static final BiConsumer[] READY = new BiConsumer[0]; + + @SuppressWarnings("unchecked") + static final BiConsumer[] TERMINATED = new BiConsumer[0]; + + static final int ADDED_STATE = 0; + static final int READY_STATE = 1; + static final int TERMINATED_STATE = 2; + + T value; + Throwable t; + + public ResolvingOperator() { + + SUBSCRIBERS.lazySet(this, EMPTY_UNSUBSCRIBED); + } + + @Override + public final void dispose() { + this.terminate(ON_DISPOSE); + } + + @Override + public final boolean isDisposed() { + return this.subscribers == TERMINATED; + } + + public final boolean isPending() { + BiConsumer[] state = this.subscribers; + return state != READY && state != TERMINATED; + } + + @Nullable + public final T valueIfResolved() { + if (this.subscribers == READY) { + T value = this.value; + if (value != null) { + return value; + } + } + + return null; + } + + final void observe(BiConsumer actual) { + for (; ; ) { + final int state = this.add(actual); + + T value = this.value; + + if (state == READY_STATE) { + if (value != null) { + actual.accept(value, null); + return; + } + // value == null means racing between invalidate and this subscriber + // thus, we have to loop again + continue; + } else if (state == TERMINATED_STATE) { + actual.accept(null, this.t); + return; + } + + return; + } + } + + /** + * Block the calling thread for the specified time, waiting for the completion of this {@code + * ReconnectMono}. If the {@link ResolvingOperator} is completed with an error a RuntimeException + * that wraps the error is thrown. + * + * @param timeout the timeout value as a {@link Duration} + * @return the value of this {@link ResolvingOperator} or {@code null} if the timeout is reached + * and the {@link ResolvingOperator} has not completed + * @throws RuntimeException if terminated with error + * @throws IllegalStateException if timed out or {@link Thread} was interrupted with {@link + * InterruptedException} + */ + @Nullable + @SuppressWarnings({"uncheked", "BusyWait"}) + public T block(@Nullable Duration timeout) { + try { + BiConsumer[] subscribers = this.subscribers; + if (subscribers == READY) { + final T value = this.value; + if (value != null) { + return value; + } else { + // value == null means racing between invalidate and this block + // thus, we have to update the state again and see what happened + subscribers = this.subscribers; + } + } + + if (subscribers == TERMINATED) { + RuntimeException re = Exceptions.propagate(this.t); + re = Exceptions.addSuppressed(re, new Exception("Terminated with an error")); + throw re; + } + + // connect once + if (subscribers == EMPTY_UNSUBSCRIBED + && SUBSCRIBERS.compareAndSet(this, EMPTY_UNSUBSCRIBED, EMPTY_SUBSCRIBED)) { + this.doSubscribe(); + } + + long delay; + if (null == timeout) { + delay = 0L; + } else { + delay = System.nanoTime() + timeout.toNanos(); + } + for (; ; ) { + subscribers = this.subscribers; + + if (subscribers == READY) { + final T value = this.value; + if (value != null) { + return value; + } else { + // value == null means racing between invalidate and this block + // thus, we have to update the state again and see what happened + subscribers = this.subscribers; + } + } + if (subscribers == TERMINATED) { + RuntimeException re = Exceptions.propagate(this.t); + re = Exceptions.addSuppressed(re, new Exception("Terminated with an error")); + throw re; + } + if (timeout != null && delay < System.nanoTime()) { + throw new IllegalStateException("Timeout on Mono blocking read"); + } + + // connect again since invalidate() has happened in between + if (subscribers == EMPTY_UNSUBSCRIBED + && SUBSCRIBERS.compareAndSet(this, EMPTY_UNSUBSCRIBED, EMPTY_SUBSCRIBED)) { + this.doSubscribe(); + } + + Thread.sleep(1); + } + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + + throw new IllegalStateException("Thread Interruption on Mono blocking read"); + } + } + + @SuppressWarnings("unchecked") + final void terminate(Throwable t) { + if (isDisposed()) { + Operators.onErrorDropped(t, Context.empty()); + return; + } + + // writes happens before volatile write + this.t = t; + + final BiConsumer[] subscribers = SUBSCRIBERS.getAndSet(this, TERMINATED); + if (subscribers == TERMINATED) { + Operators.onErrorDropped(t, Context.empty()); + return; + } + + this.doOnDispose(); + + this.doFinally(); + + for (BiConsumer consumer : subscribers) { + consumer.accept(null, t); + } + } + + final void complete(T value) { + BiConsumer[] subscribers = this.subscribers; + if (subscribers == TERMINATED) { + this.doOnValueExpired(value); + return; + } + + this.value = value; + + for (; ; ) { + // ensures TERMINATE is going to be replaced with READY + if (SUBSCRIBERS.compareAndSet(this, subscribers, READY)) { + break; + } + + subscribers = this.subscribers; + + if (subscribers == TERMINATED) { + this.doFinally(); + return; + } + } + + this.doOnValueResolved(value); + + for (BiConsumer consumer : subscribers) { + consumer.accept(value, null); + } + } + + protected void doOnValueResolved(T value) { + // no ops + } + + final void doFinally() { + if (WIP.getAndIncrement(this) != 0) { + return; + } + + int m = 1; + T value; + + for (; ; ) { + value = this.value; + if (value != null && isDisposed()) { + this.value = null; + this.doOnValueExpired(value); + return; + } + + m = WIP.addAndGet(this, -m); + if (m == 0) { + return; + } + } + } + + final void invalidate() { + if (this.subscribers == TERMINATED) { + return; + } + + final BiConsumer[] subscribers = this.subscribers; + + if (subscribers == READY) { + // guarded section to ensure we expire value exactly once if there is racing + if (WIP.getAndIncrement(this) != 0) { + return; + } + + final T value = this.value; + if (value != null) { + this.value = null; + this.doOnValueExpired(value); + } + + int m = 1; + for (; ; ) { + if (isDisposed()) { + return; + } + + m = WIP.addAndGet(this, -m); + if (m == 0) { + break; + } + } + + SUBSCRIBERS.compareAndSet(this, READY, EMPTY_UNSUBSCRIBED); + } + } + + protected void doOnValueExpired(T value) { + // no ops + } + + protected void doOnDispose() { + // no ops + } + + public final boolean connect() { + for (; ; ) { + final BiConsumer[] a = this.subscribers; + + if (a == TERMINATED) { + return false; + } + + if (a == READY) { + return true; + } + + if (a != EMPTY_UNSUBSCRIBED) { + // do nothing if already started + return true; + } + + if (SUBSCRIBERS.compareAndSet(this, a, EMPTY_SUBSCRIBED)) { + this.doSubscribe(); + return true; + } + } + } + + final int add(BiConsumer ps) { + for (; ; ) { + BiConsumer[] a = this.subscribers; + + if (a == TERMINATED) { + return TERMINATED_STATE; + } + + if (a == READY) { + return READY_STATE; + } + + int n = a.length; + @SuppressWarnings("unchecked") + BiConsumer[] b = new BiConsumer[n + 1]; + System.arraycopy(a, 0, b, 0, n); + b[n] = ps; + + if (SUBSCRIBERS.compareAndSet(this, a, b)) { + if (a == EMPTY_UNSUBSCRIBED) { + this.doSubscribe(); + } + return ADDED_STATE; + } + } + } + + protected void doSubscribe() { + // no ops + } + + @SuppressWarnings("unchecked") + final void remove(BiConsumer ps) { + for (; ; ) { + BiConsumer[] a = this.subscribers; + int n = a.length; + if (n == 0) { + return; + } + + int j = -1; + for (int i = 0; i < n; i++) { + if (a[i] == ps) { + j = i; + break; + } + } + + if (j < 0) { + return; + } + + BiConsumer[] b; + + if (n == 1) { + b = EMPTY_SUBSCRIBED; + } else { + b = new BiConsumer[n - 1]; + System.arraycopy(a, 0, b, 0, j); + System.arraycopy(a, j + 1, b, j, n - j - 1); + } + if (SUBSCRIBERS.compareAndSet(this, a, b)) { + return; + } + } + } + + abstract static class DeferredResolution + implements CoreSubscriber, Subscription, Scannable, BiConsumer { + + final ResolvingOperator parent; + final CoreSubscriber actual; + + volatile long requested; + + @SuppressWarnings("rawtypes") + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(DeferredResolution.class, "requested"); + + static final long STATE_SUBSCRIBED = -1; + static final long STATE_CANCELLED = Long.MIN_VALUE; + + Subscription s; + boolean done; + + DeferredResolution(ResolvingOperator parent, CoreSubscriber actual) { + this.parent = parent; + this.actual = actual; + } + + @Override + public final Context currentContext() { + return this.actual.currentContext(); + } + + @Nullable + @Override + public Object scanUnsafe(Attr key) { + long state = this.requested; + + if (key == Attr.PARENT) { + return this.s; + } + if (key == Attr.ACTUAL) { + return this.parent; + } + if (key == Attr.TERMINATED) { + return this.done; + } + if (key == Attr.CANCELLED) { + return state == STATE_CANCELLED; + } + + return null; + } + + @Override + public final void onSubscribe(Subscription s) { + final long state = this.requested; + Subscription a = this.s; + if (state == STATE_CANCELLED) { + s.cancel(); + return; + } + if (a != null) { + s.cancel(); + return; + } + + long r; + long accumulated = 0; + for (; ; ) { + r = this.requested; + + if (r == STATE_CANCELLED || r == STATE_SUBSCRIBED) { + s.cancel(); + return; + } + + this.s = s; + + long toRequest = r - accumulated; + if (toRequest > 0) { // if there is something, + s.request(toRequest); // then we do a request on the given subscription + } + accumulated = r; + + if (REQUESTED.compareAndSet(this, r, STATE_SUBSCRIBED)) { + return; + } + } + } + + @Override + public final void onNext(T payload) { + this.actual.onNext(payload); + } + + @Override + public final void onError(Throwable t) { + if (this.done) { + Operators.onErrorDropped(t, this.actual.currentContext()); + return; + } + + this.done = true; + this.actual.onError(t); + } + + @Override + public final void onComplete() { + if (this.done) { + return; + } + + this.done = true; + this.actual.onComplete(); + } + + @Override + public void request(long n) { + if (Operators.validate(n)) { + long r = this.requested; // volatile read beforehand + + if (r > STATE_SUBSCRIBED) { // works only in case onSubscribe has not happened + long u; + for (; ; ) { // normal CAS loop with overflow protection + if (r == Long.MAX_VALUE) { + // if r == Long.MAX_VALUE then we dont care and we can loose this + // request just in case of racing + return; + } + u = Operators.addCap(r, n); + if (REQUESTED.compareAndSet(this, r, u)) { + // Means increment happened before onSubscribe + return; + } else { + // Means increment happened after onSubscribe + + // update new state to see what exactly happened (onSubscribe |cancel | requestN) + r = this.requested; + + // check state (expect -1 | -2 to exit, otherwise repeat) + if (r < 0) { + break; + } + } + } + } + + if (r == STATE_CANCELLED) { // if canceled, just exit + return; + } + + // if onSubscribe -> subscription exists (and we sure of that because volatile read + // after volatile write) so we can execute requestN on the subscription + this.s.request(n); + } + } + + public boolean isCancelled() { + return this.requested == STATE_CANCELLED; + } + + public void cancel() { + long state = REQUESTED.getAndSet(this, STATE_CANCELLED); + if (state == STATE_CANCELLED) { + return; + } + + if (state == STATE_SUBSCRIBED) { + this.s.cancel(); + } else { + this.parent.remove(this); + } + } + } + + static class MonoDeferredResolutionOperator extends Operators.MonoSubscriber + implements BiConsumer { + + final ResolvingOperator parent; + + MonoDeferredResolutionOperator(ResolvingOperator parent, CoreSubscriber actual) { + super(actual); + this.parent = parent; + } + + @Override + public void accept(T t, Throwable throwable) { + if (throwable != null) { + onError(throwable); + return; + } + + complete(t); + } + + @Override + public void cancel() { + if (!isCancelled()) { + super.cancel(); + this.parent.remove(this); + } + } + + @Override + public void onComplete() { + if (!isCancelled()) { + this.actual.onComplete(); + } + } + + @Override + public void onError(Throwable t) { + if (isCancelled()) { + Operators.onErrorDropped(t, currentContext()); + } else { + this.actual.onError(t); + } + } + + @Override + public Object scanUnsafe(Attr key) { + if (key == Attr.PARENT) return this.parent; + return super.scanUnsafe(key); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ResponderFrameHandler.java b/rsocket-core/src/main/java/io/rsocket/core/ResponderFrameHandler.java new file mode 100644 index 000000000..27cc8db9a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ResponderFrameHandler.java @@ -0,0 +1,38 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +interface ResponderFrameHandler extends FrameHandler { + + Logger logger = LoggerFactory.getLogger(ResponderFrameHandler.class); + + @Override + default void handleComplete() {} + + @Override + default void handleError(Throwable t) { + logger.debug("Dropped error", t); + handleCancel(); + } + + @Override + default void handleRequestN(long n) { + // no ops + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ResponderLeaseTracker.java b/rsocket-core/src/main/java/io/rsocket/core/ResponderLeaseTracker.java new file mode 100644 index 000000000..fc7442f4a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ResponderLeaseTracker.java @@ -0,0 +1,112 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Availability; +import io.rsocket.DuplexConnection; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.lease.Lease; +import io.rsocket.lease.LeaseSender; +import io.rsocket.lease.MissingLeaseException; +import reactor.core.Disposable; +import reactor.core.publisher.BaseSubscriber; +import reactor.util.annotation.Nullable; + +final class ResponderLeaseTracker extends BaseSubscriber + implements Disposable, Availability { + + final String tag; + final ByteBufAllocator allocator; + final DuplexConnection connection; + + @Nullable volatile MutableLease currentLease; + + ResponderLeaseTracker(String tag, DuplexConnection connection, LeaseSender leaseSender) { + this.tag = tag; + this.connection = connection; + this.allocator = connection.alloc(); + + leaseSender.send().subscribe(this); + } + + @Nullable + Throwable use() { + final MutableLease lease = this.currentLease; + final String tag = this.tag; + + if (lease == null) { + return new MissingLeaseException(String.format("[%s] Lease was not issued yet", tag)); + } + + if (isExpired(lease)) { + return new MissingLeaseException(String.format("[%s] Missing leases. Lease is expired", tag)); + } + + final int allowedRequests = lease.allowedRequests; + final int remainingRequests = lease.remainingRequests; + if (remainingRequests <= 0) { + return new MissingLeaseException( + String.format( + "[%s] Missing leases. Issued [%s] request allowance is used", tag, allowedRequests)); + } + + lease.remainingRequests = remainingRequests - 1; + + return null; + } + + @Override + protected void hookOnNext(Lease lease) { + final int allowedRequests = lease.numberOfRequests(); + final int ttl = lease.timeToLiveInMillis(); + final long expireAt = lease.expirationTime(); + + this.currentLease = new MutableLease(allowedRequests, expireAt); + this.connection.sendFrame( + 0, LeaseFrameCodec.encode(this.allocator, ttl, allowedRequests, lease.metadata())); + } + + @Override + public double availability() { + final MutableLease lease = this.currentLease; + + if (lease == null || isExpired(lease)) { + return 0; + } + + return lease.remainingRequests / (double) lease.allowedRequests; + } + + static boolean isExpired(MutableLease currentLease) { + return System.currentTimeMillis() >= currentLease.expireAt; + } + + static final class MutableLease { + final int allowedRequests; + final long expireAt; + + int remainingRequests; + + MutableLease(int allowedRequests, long expireAt) { + this.allowedRequests = allowedRequests; + this.expireAt = expireAt; + + this.remainingRequests = allowedRequests; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/Resume.java b/rsocket-core/src/main/java/io/rsocket/core/Resume.java new file mode 100644 index 000000000..fa0eedbfa --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/Resume.java @@ -0,0 +1,177 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.ResumeFrameCodec; +import io.rsocket.resume.InMemoryResumableFramesStore; +import io.rsocket.resume.ResumableFramesStore; +import java.time.Duration; +import java.util.Objects; +import java.util.function.Function; +import java.util.function.Supplier; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.util.retry.Retry; + +/** + * Simple holder of configuration settings for the RSocket Resume capability. This can be used to + * configure an {@link RSocketConnector} or an {@link RSocketServer} except for {@link + * #retry(Retry)} and {@link #token(Supplier)} which apply only to the client side. + */ +public class Resume { + private static final Logger logger = LoggerFactory.getLogger(Resume.class); + + private Duration sessionDuration = Duration.ofMinutes(2); + + /* Storage */ + private boolean cleanupStoreOnKeepAlive; + private Function storeFactory; + private Duration streamTimeout = Duration.ofSeconds(10); + + /* Client only */ + private Supplier tokenSupplier = ResumeFrameCodec::generateResumeToken; + private Retry retry = + Retry.backoff(Long.MAX_VALUE, Duration.ofSeconds(1)) + .maxBackoff(Duration.ofSeconds(16)) + .jitter(1.0) + .doBeforeRetry(signal -> logger.debug("Connection error", signal.failure())); + + public Resume() {} + + /** + * The maximum time for a client to keep trying to reconnect. During this time client and server + * continue to store unsent frames to keep the session warm and ready to resume. + * + *

By default this is set to 2 minutes. + * + * @param sessionDuration the max duration for a session + * @return the same instance for method chaining + */ + public Resume sessionDuration(Duration sessionDuration) { + this.sessionDuration = Objects.requireNonNull(sessionDuration); + return this; + } + + /** + * When this property is enabled, hints from {@code KEEPALIVE} frames about how much data has been + * received by the other side, is used to proactively clean frames from the {@link + * #storeFactory(Function) store}. + * + *

By default this is set to {@code false} in which case information from {@code KEEPALIVE} is + * ignored and old frames from the store are removed only when the store runs out of space. + * + * @return the same instance for method chaining + */ + public Resume cleanupStoreOnKeepAlive() { + this.cleanupStoreOnKeepAlive = true; + return this; + } + + /** + * Configure a factory to create the storage for buffering (or persisting) a window of frames that + * may need to be sent again to resume after a dropped connection. + * + *

By default {@link InMemoryResumableFramesStore} is used with its cache size set to 100,000 + * bytes. When the cache fills up, the oldest frames are gradually removed to create space for new + * ones. + * + * @param storeFactory the factory to use to create the store + * @return the same instance for method chaining + */ + public Resume storeFactory( + Function storeFactory) { + this.storeFactory = storeFactory; + return this; + } + + /** + * A {@link reactor.core.publisher.Flux#timeout(Duration) timeout} value to apply to the resumed + * session stream obtained from the {@link #storeFactory(Function) store} after a reconnect. The + * resume stream must not take longer than the specified time to emit each frame. + * + *

By default this is set to 10 seconds. + * + * @param streamTimeout the timeout value for resuming a session stream + * @return the same instance for method chaining + */ + public Resume streamTimeout(Duration streamTimeout) { + this.streamTimeout = Objects.requireNonNull(streamTimeout); + return this; + } + + /** + * Configure the logic for reconnecting. This setting is for use with {@link + * RSocketConnector#resume(Resume)} on the client side only. + * + *

By default this is set to: + * + *

{@code
+   * Retry.backoff(Long.MAX_VALUE, Duration.ofSeconds(1))
+   *     .maxBackoff(Duration.ofSeconds(16))
+   *     .jitter(1.0)
+   * }
+ * + * @param retry the {@code Retry} spec to use when attempting to reconnect + * @return the same instance for method chaining + */ + public Resume retry(Retry retry) { + this.retry = retry; + return this; + } + + /** + * Customize the generation of the resume identification token used to resume. This setting is for + * use with {@link RSocketConnector#resume(Resume)} on the client side only. + * + *

By default this is {@code ResumeFrameFlyweight::generateResumeToken}. + * + * @param supplier a custom generator for a resume identification token + * @return the same instance for method chaining + */ + public Resume token(Supplier supplier) { + this.tokenSupplier = supplier; + return this; + } + + // Package private accessors + + Duration getSessionDuration() { + return sessionDuration; + } + + boolean isCleanupStoreOnKeepAlive() { + return cleanupStoreOnKeepAlive; + } + + Function getStoreFactory(String tag) { + return storeFactory != null + ? storeFactory + : token -> new InMemoryResumableFramesStore(tag, token, 100_000); + } + + Duration getStreamTimeout() { + return streamTimeout; + } + + Retry getRetry() { + return retry; + } + + Supplier getTokenSupplier() { + return tokenSupplier; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/SendUtils.java b/rsocket-core/src/main/java/io/rsocket/core/SendUtils.java new file mode 100644 index 000000000..568dada2e --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/SendUtils.java @@ -0,0 +1,335 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.isFragmentable; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCounted; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.exceptions.CanceledException; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import java.util.function.Consumer; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; + +final class SendUtils { + private static final Consumer DROPPED_ELEMENTS_CONSUMER = + data -> { + if (data instanceof ReferenceCounted) { + try { + ReferenceCounted referenceCounted = (ReferenceCounted) data; + referenceCounted.release(); + } catch (Throwable e) { + // ignored + } + } + }; + + static final Context DISCARD_CONTEXT = Operators.enableOnDiscard(null, DROPPED_ELEMENTS_CONSUMER); + + static void sendReleasingPayload( + int streamId, + FrameType frameType, + int mtu, + Payload payload, + DuplexConnection connection, + ByteBufAllocator allocator, + boolean requester) { + + final boolean hasMetadata = payload.hasMetadata(); + final ByteBuf metadata = hasMetadata ? payload.metadata() : null; + final ByteBuf data = payload.data(); + + boolean fragmentable; + try { + fragmentable = isFragmentable(mtu, data, metadata, false); + } catch (IllegalReferenceCountException | NullPointerException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, requester, false, e); + throw e; + } + + if (fragmentable) { + final ByteBuf slicedData = data.slice(); + final ByteBuf slicedMetadata = hasMetadata ? metadata.slice() : Unpooled.EMPTY_BUFFER; + + final ByteBuf first; + try { + first = + FragmentationUtils.encodeFirstFragment( + allocator, mtu, frameType, streamId, hasMetadata, slicedMetadata, slicedData); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, requester, false, e); + throw e; + } + + connection.sendFrame(streamId, first); + + boolean complete = frameType == FrameType.NEXT_COMPLETE; + while (slicedData.isReadable() || slicedMetadata.isReadable()) { + final ByteBuf following; + try { + following = + FragmentationUtils.encodeFollowsFragment( + allocator, mtu, streamId, complete, slicedMetadata, slicedData); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, requester, true, e); + throw e; + } + connection.sendFrame(streamId, following); + } + + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, true, true, e); + throw e; + } + } else { + final ByteBuf dataRetainedSlice = data.retainedSlice(); + + final ByteBuf metadataRetainedSlice; + try { + metadataRetainedSlice = hasMetadata ? metadata.retainedSlice() : null; + } catch (IllegalReferenceCountException e) { + dataRetainedSlice.release(); + + sendTerminalFrame(streamId, frameType, connection, allocator, requester, false, e); + throw e; + } + + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + dataRetainedSlice.release(); + if (hasMetadata) { + metadataRetainedSlice.release(); + } + + sendTerminalFrame(streamId, frameType, connection, allocator, requester, false, e); + throw e; + } + + final ByteBuf requestFrame; + switch (frameType) { + case REQUEST_FNF: + requestFrame = + RequestFireAndForgetFrameCodec.encode( + allocator, streamId, false, metadataRetainedSlice, dataRetainedSlice); + break; + case REQUEST_RESPONSE: + requestFrame = + RequestResponseFrameCodec.encode( + allocator, streamId, false, metadataRetainedSlice, dataRetainedSlice); + break; + case PAYLOAD: + case NEXT: + case NEXT_COMPLETE: + requestFrame = + PayloadFrameCodec.encode( + allocator, + streamId, + false, + frameType == FrameType.NEXT_COMPLETE, + frameType != FrameType.PAYLOAD, + metadataRetainedSlice, + dataRetainedSlice); + break; + default: + throw new IllegalArgumentException("Unsupported frame type " + frameType); + } + + connection.sendFrame(streamId, requestFrame); + } + } + + static void sendReleasingPayload( + int streamId, + FrameType frameType, + long initialRequestN, + int mtu, + Payload payload, + DuplexConnection connection, + ByteBufAllocator allocator, + boolean complete) { + + final boolean hasMetadata = payload.hasMetadata(); + final ByteBuf metadata = hasMetadata ? payload.metadata() : null; + final ByteBuf data = payload.data(); + + boolean fragmentable; + try { + fragmentable = isFragmentable(mtu, data, metadata, true); + } catch (IllegalReferenceCountException | NullPointerException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, true, false, e); + throw e; + } + + if (fragmentable) { + final ByteBuf slicedData = data.slice(); + final ByteBuf slicedMetadata = hasMetadata ? metadata.slice() : Unpooled.EMPTY_BUFFER; + + final ByteBuf first; + try { + first = + FragmentationUtils.encodeFirstFragment( + allocator, + mtu, + initialRequestN, + frameType, + streamId, + hasMetadata, + slicedMetadata, + slicedData); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, true, false, e); + throw e; + } + + connection.sendFrame(streamId, first); + + while (slicedData.isReadable() || slicedMetadata.isReadable()) { + final ByteBuf following; + try { + following = + FragmentationUtils.encodeFollowsFragment( + allocator, mtu, streamId, complete, slicedMetadata, slicedData); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, true, true, e); + throw e; + } + connection.sendFrame(streamId, following); + } + + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, true, true, e); + throw e; + } + } else { + final ByteBuf dataRetainedSlice = data.retainedSlice(); + + final ByteBuf metadataRetainedSlice; + try { + metadataRetainedSlice = hasMetadata ? metadata.retainedSlice() : null; + } catch (IllegalReferenceCountException e) { + dataRetainedSlice.release(); + + sendTerminalFrame(streamId, frameType, connection, allocator, true, false, e); + throw e; + } + + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + dataRetainedSlice.release(); + if (hasMetadata) { + metadataRetainedSlice.release(); + } + + sendTerminalFrame(streamId, frameType, connection, allocator, true, false, e); + throw e; + } + + final ByteBuf requestFrame; + switch (frameType) { + case REQUEST_STREAM: + requestFrame = + RequestStreamFrameCodec.encode( + allocator, + streamId, + false, + initialRequestN, + metadataRetainedSlice, + dataRetainedSlice); + break; + case REQUEST_CHANNEL: + requestFrame = + RequestChannelFrameCodec.encode( + allocator, + streamId, + false, + complete, + initialRequestN, + metadataRetainedSlice, + dataRetainedSlice); + break; + default: + throw new IllegalArgumentException("Unsupported frame type " + frameType); + } + + connection.sendFrame(streamId, requestFrame); + } + } + + static void sendTerminalFrame( + int streamId, + FrameType frameType, + DuplexConnection connection, + ByteBufAllocator allocator, + boolean requester, + boolean onFollowingFrame, + Throwable t) { + + if (onFollowingFrame) { + if (requester) { + final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); + connection.sendFrame(streamId, cancelFrame); + } else { + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException( + "Failed to encode fragmented " + + frameType + + " frame. Cause: " + + t.getMessage())); + connection.sendFrame(streamId, errorFrame); + } + } else { + switch (frameType) { + case NEXT_COMPLETE: + case NEXT: + case PAYLOAD: + if (requester) { + final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); + connection.sendFrame(streamId, cancelFrame); + } else { + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException( + "Failed to encode " + frameType + " frame. Cause: " + t.getMessage())); + connection.sendFrame(streamId, errorFrame); + } + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java b/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java new file mode 100644 index 000000000..5aae22e89 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java @@ -0,0 +1,165 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.keepalive.KeepAliveHandler.*; + +import io.netty.buffer.ByteBuf; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.exceptions.RejectedResumeException; +import io.rsocket.exceptions.UnsupportedSetupException; +import io.rsocket.frame.ResumeFrameCodec; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.keepalive.KeepAliveHandler; +import io.rsocket.resume.*; +import java.nio.channels.ClosedChannelException; +import java.time.Duration; +import java.util.function.BiFunction; +import java.util.function.Function; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; + +abstract class ServerSetup { + + final Duration timeout; + + protected ServerSetup(Duration timeout) { + this.timeout = timeout; + } + + Mono> init(DuplexConnection connection) { + return Mono.>create( + sink -> sink.onRequest(__ -> new SetupHandlingDuplexConnection(connection, sink))) + .timeout(this.timeout) + .or(connection.onClose().then(Mono.error(ClosedChannelException::new))); + } + + abstract Mono acceptRSocketSetup( + ByteBuf frame, + DuplexConnection clientServerConnection, + BiFunction> then); + + abstract Mono acceptRSocketResume(ByteBuf frame, DuplexConnection connection); + + void dispose() {} + + void sendError(DuplexConnection duplexConnection, RSocketErrorException exception) { + duplexConnection.sendErrorAndClose(exception); + duplexConnection.receive().subscribe(); + } + + static class DefaultServerSetup extends ServerSetup { + + DefaultServerSetup(Duration timeout) { + super(timeout); + } + + @Override + public Mono acceptRSocketSetup( + ByteBuf frame, + DuplexConnection duplexConnection, + BiFunction> then) { + + if (SetupFrameCodec.resumeEnabled(frame)) { + sendError(duplexConnection, new UnsupportedSetupException("resume not supported")); + return duplexConnection.onClose(); + } else { + return then.apply(new DefaultKeepAliveHandler(), duplexConnection); + } + } + + @Override + public Mono acceptRSocketResume(ByteBuf frame, DuplexConnection duplexConnection) { + sendError(duplexConnection, new RejectedResumeException("resume not supported")); + return duplexConnection.onClose(); + } + } + + static class ResumableServerSetup extends ServerSetup { + private final SessionManager sessionManager; + private final Duration resumeSessionDuration; + private final Duration resumeStreamTimeout; + private final Function resumeStoreFactory; + private final boolean cleanupStoreOnKeepAlive; + + ResumableServerSetup( + Duration timeout, + SessionManager sessionManager, + Duration resumeSessionDuration, + Duration resumeStreamTimeout, + Function resumeStoreFactory, + boolean cleanupStoreOnKeepAlive) { + super(timeout); + this.sessionManager = sessionManager; + this.resumeSessionDuration = resumeSessionDuration; + this.resumeStreamTimeout = resumeStreamTimeout; + this.resumeStoreFactory = resumeStoreFactory; + this.cleanupStoreOnKeepAlive = cleanupStoreOnKeepAlive; + } + + @Override + public Mono acceptRSocketSetup( + ByteBuf frame, + DuplexConnection duplexConnection, + BiFunction> then) { + + if (SetupFrameCodec.resumeEnabled(frame)) { + ByteBuf resumeToken = SetupFrameCodec.resumeToken(frame); + + final ResumableFramesStore resumableFramesStore = resumeStoreFactory.apply(resumeToken); + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "server", resumeToken, duplexConnection, resumableFramesStore); + final ServerRSocketSession serverRSocketSession = + new ServerRSocketSession( + resumeToken, + resumableDuplexConnection, + duplexConnection, + resumableFramesStore, + resumeSessionDuration, + cleanupStoreOnKeepAlive); + + sessionManager.save(serverRSocketSession, resumeToken); + + return then.apply( + new ResumableKeepAliveHandler( + resumableDuplexConnection, serverRSocketSession, serverRSocketSession), + resumableDuplexConnection); + } else { + return then.apply(new DefaultKeepAliveHandler(), duplexConnection); + } + } + + @Override + public Mono acceptRSocketResume(ByteBuf frame, DuplexConnection duplexConnection) { + ServerRSocketSession session = sessionManager.get(ResumeFrameCodec.token(frame)); + if (session != null) { + session.resumeWith(frame, duplexConnection); + return duplexConnection.onClose(); + } else { + sendError(duplexConnection, new RejectedResumeException("unknown resume token")); + return duplexConnection.onClose(); + } + } + + @Override + public void dispose() { + sessionManager.dispose(); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/SetupHandlingDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/core/SetupHandlingDuplexConnection.java new file mode 100644 index 000000000..3beedf97f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/SetupHandlingDuplexConnection.java @@ -0,0 +1,176 @@ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import java.net.SocketAddress; +import java.nio.channels.ClosedChannelException; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +class SetupHandlingDuplexConnection extends Flux + implements DuplexConnection, CoreSubscriber, Subscription { + + final DuplexConnection source; + final MonoSink> sink; + + Subscription s; + boolean firstFrameReceived = false; + + CoreSubscriber actual; + + boolean done; + Throwable t; + + SetupHandlingDuplexConnection( + DuplexConnection source, MonoSink> sink) { + this.source = source; + this.sink = sink; + + source.receive().subscribe(this); + } + + @Override + public void dispose() { + source.dispose(); + } + + @Override + public boolean isDisposed() { + return source.isDisposed(); + } + + @Override + public Mono onClose() { + return source.onClose(); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + source.sendFrame(streamId, frame); + } + + @Override + public Flux receive() { + return this; + } + + @Override + public SocketAddress remoteAddress() { + return source.remoteAddress(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + if (done) { + final Throwable t = this.t; + if (t == null) { + Operators.complete(actual); + } else { + Operators.error(actual, t); + } + return; + } + + this.actual = actual; + actual.onSubscribe(this); + } + + @Override + public void request(long n) { + if (n != Long.MAX_VALUE) { + actual.onError(new IllegalArgumentException("Only unbounded request is allowed")); + return; + } + + s.request(Long.MAX_VALUE); + } + + @Override + public void cancel() { + source.dispose(); + s.cancel(); + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + s.request(1); + } + } + + @Override + public void onNext(ByteBuf frame) { + if (!firstFrameReceived) { + firstFrameReceived = true; + sink.success(Tuples.of(frame, this)); + return; + } + + actual.onNext(frame); + } + + @Override + public void onError(Throwable t) { + if (done) { + Operators.onErrorDropped(t, Context.empty()); + return; + } + + this.done = true; + this.t = t; + + if (!firstFrameReceived) { + sink.error(t); + return; + } + + final CoreSubscriber actual = this.actual; + if (actual != null) { + actual.onError(t); + } + } + + @Override + public void onComplete() { + if (done) { + return; + } + + this.done = true; + + if (!firstFrameReceived) { + sink.error(new ClosedChannelException()); + return; + } + + final CoreSubscriber actual = this.actual; + if (actual != null) { + actual.onComplete(); + } + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + source.sendErrorAndClose(e); + } + + @Override + public ByteBufAllocator alloc() { + return source.alloc(); + } + + @Override + public String toString() { + return "SetupHandlingDuplexConnection{" + "source=" + source + ", done=" + done + '}'; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/SlowFireAndForgetRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/SlowFireAndForgetRequesterMono.java new file mode 100644 index 000000000..3035696b3 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/SlowFireAndForgetRequesterMono.java @@ -0,0 +1,255 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.isReadyToSendFirstFrame; +import static io.rsocket.core.StateUtils.isSubscribedOrTerminated; +import static io.rsocket.core.StateUtils.isTerminated; +import static io.rsocket.core.StateUtils.lazyTerminate; +import static io.rsocket.core.StateUtils.markReadyToSendFirstFrame; +import static io.rsocket.core.StateUtils.markSubscribed; +import static io.rsocket.core.StateUtils.markTerminated; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; + +final class SlowFireAndForgetRequesterMono extends Mono + implements LeasePermitHandler, Subscription, Scannable { + + volatile long state; + + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(SlowFireAndForgetRequesterMono.class, "state"); + + final Payload payload; + + final ByteBufAllocator allocator; + final int mtu; + final int maxFrameLength; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + + @Nullable final RequesterLeaseTracker requesterLeaseTracker; + @Nullable final RequestInterceptor requestInterceptor; + + CoreSubscriber actual; + + SlowFireAndForgetRequesterMono( + Payload payload, RequesterResponderSupport requesterResponderSupport) { + this.allocator = requesterResponderSupport.getAllocator(); + this.payload = payload; + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + this.requesterLeaseTracker = requesterResponderSupport.getRequesterLeaseTracker(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + final boolean leaseEnabled = requesterLeaseTracker != null; + long previousState = markSubscribed(STATE, this, !leaseEnabled); + if (isSubscribedOrTerminated(previousState)) { + final IllegalStateException e = + new IllegalStateException("FireAndForgetMono allows only a single Subscriber"); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + + Operators.error(actual, e); + return; + } + + final Payload p = this.payload; + int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, p, false)) { + lazyTerminate(STATE, this); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + Operators.error(actual, e); + return; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + + Operators.error(actual, e); + return; + } + + this.actual = actual; + actual.onSubscribe(this); + + if (leaseEnabled) { + requesterLeaseTracker.issue(this); + return; + } + + sendFirstFrame(p); + } + + @Override + public boolean handlePermit() { + final long previousState = markReadyToSendFirstFrame(STATE, this); + + if (isTerminated(previousState)) { + return false; + } + + sendFirstFrame(this.payload); + return true; + } + + void sendFirstFrame(Payload p) { + final CoreSubscriber actual = this.actual; + final int streamId; + try { + streamId = this.requesterResponderSupport.getNextStreamId(); + } catch (Throwable t) { + lazyTerminate(STATE, this); + + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + actual.onError(ut); + return; + } + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onStart(streamId, FrameType.REQUEST_FNF, p.metadata()); + } + + try { + if (isTerminated(this.state)) { + p.release(); + + if (interceptor != null) { + interceptor.onCancel(streamId, FrameType.REQUEST_FNF); + } + + return; + } + + sendReleasingPayload( + streamId, FrameType.REQUEST_FNF, mtu, p, this.connection, this.allocator, true); + } catch (Throwable e) { + lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, e); + } + + actual.onError(e); + return; + } + + lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, null); + } + + actual.onComplete(); + } + + @Override + public void request(long n) { + // no ops + } + + @Override + public void cancel() { + final long previousState = markTerminated(STATE, this); + + if (isTerminated(previousState)) { + return; + } + + if (!isReadyToSendFirstFrame(previousState)) { + this.payload.release(); + } + } + + @Override + public final void handlePermitError(Throwable cause) { + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(cause, this.actual.currentContext()); + return; + } + + final Payload p = this.payload; + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(cause, FrameType.REQUEST_RESPONSE, p.metadata()); + } + + p.release(); + + this.actual.onError(cause); + } + + @Override + public Object scanUnsafe(Attr key) { + return null; // no particular key to be represented, still useful in hooks + } + + @Override + @NonNull + public String stepName() { + return "source(FireAndForgetMono)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/StateUtils.java b/rsocket-core/src/main/java/io/rsocket/core/StateUtils.java new file mode 100644 index 000000000..2b6a0e09a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/StateUtils.java @@ -0,0 +1,493 @@ +package io.rsocket.core; + +import java.util.concurrent.atomic.AtomicLongFieldUpdater; + +final class StateUtils { + + /** Volatile Long Field bit mask that allows extract flags stored in the field */ + static final long FLAGS_MASK = + 0b111111111111111111111111111111111_0000000000000000000000000000000L; + /** Volatile Long Field bit mask that allows extract int RequestN stored in the field */ + static final long REQUEST_MASK = + 0b000000000000000000000000000000000_1111111111111111111111111111111L; + /** Bit Flag that indicates Requester Producer has been subscribed once */ + static final long SUBSCRIBED_FLAG = + 0b000000000000000000000000000000001_0000000000000000000000000000000L; + /** Bit Flag that indicates that the first payload in RequestChannel scenario is received */ + static final long FIRST_PAYLOAD_RECEIVED_FLAG = + 0b000000000000000000000000000000010_0000000000000000000000000000000L; + /** + * Bit Flag that indicates that the logical stream is ready to send the first initial frame + * (applicable for requester only) + */ + static final long READY_TO_SEND_FIRST_FRAME_FLAG = + 0b000000000000000000000000000000100_0000000000000000000000000000000L; + /** + * Bit Flag that indicates that sent first initial frame was sent (in case of requester) or + * consumed (if responder) + */ + static final long FIRST_FRAME_SENT_FLAG = + 0b000000000000000000000000000001000_0000000000000000000000000000000L; + /** Bit Flag that indicates that there is a frame being reassembled */ + static final long REASSEMBLING_FLAG = + 0b000000000000000000000000000010000_0000000000000000000000000000000L; + /** + * Bit Flag that indicates requestChannel stream is half terminated. In this case flag indicates + * that the inbound is terminated + */ + static final long INBOUND_TERMINATED_FLAG = + 0b000000000000000000000000000100000_0000000000000000000000000000000L; + /** + * Bit Flag that indicates requestChannel stream is half terminated. In this case flag indicates + * that the outbound is terminated + */ + static final long OUTBOUND_TERMINATED_FLAG = + 0b000000000000000000000000001000000_0000000000000000000000000000000L; + /** Initial state for any request operator */ + static final long UNSUBSCRIBED_STATE = + 0b000000000000000000000000000000000_0000000000000000000000000000000L; + /** State that indicates request operator was terminated */ + static final long TERMINATED_STATE = + 0b100000000000000000000000000000000_0000000000000000000000000000000L; + + /** + * Adds (if possible) to the given state the {@link #SUBSCRIBED_FLAG} flag which indicates that + * the given stream has already been subscribed once + * + *

Note, the flag will not be added if the stream has already been terminated or if the stream + * has already been subscribed once + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markSubscribed(AtomicLongFieldUpdater updater, T instance) { + return markSubscribed(updater, instance, false); + } + + /** + * Adds (if possible) to the given state the {@link #SUBSCRIBED_FLAG} flag which indicates that + * the given stream has already been subscribed once + * + *

Note, the flag will not be added if the stream has already been terminated or if the stream + * has already been subscribed once + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param markPrepared indicates whether the given instance should be marked as prepared + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markSubscribed( + AtomicLongFieldUpdater updater, T instance, boolean markPrepared) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & SUBSCRIBED_FLAG) == SUBSCRIBED_FLAG) { + return state; + } + + if (updater.compareAndSet( + instance, + state, + state | SUBSCRIBED_FLAG | (markPrepared ? READY_TO_SEND_FIRST_FRAME_FLAG : 0))) { + return state; + } + } + } + + /** + * Indicates that the given stream has already been subscribed once + * + * @param state to check whether stream is subscribed + * @return true if the {@link #SUBSCRIBED_FLAG} flag is set + */ + static boolean isSubscribed(long state) { + return (state & SUBSCRIBED_FLAG) == SUBSCRIBED_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #FIRST_FRAME_SENT_FLAG} flag which indicates + * that the first frame has already set and logical stream has already been established. + * + *

Note, the flag will not be added if the stream has already been terminated or if the stream + * has already been established once + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markFirstFrameSent(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & FIRST_FRAME_SENT_FLAG) == FIRST_FRAME_SENT_FLAG) { + return state; + } + + if (updater.compareAndSet(instance, state, state | FIRST_FRAME_SENT_FLAG)) { + return state; + } + } + } + + /** + * Indicates that the first frame which established logical stream has already been sent + * + * @param state to check whether stream is established + * @return true if the {@link #FIRST_FRAME_SENT_FLAG} flag is set + */ + static boolean isFirstFrameSent(long state) { + return (state & FIRST_FRAME_SENT_FLAG) == FIRST_FRAME_SENT_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #READY_TO_SEND_FIRST_FRAME_FLAG} flag which + * indicates that the logical stream is ready for initial frame sending. + * + *

Note, the flag will not be added if the stream has already been terminated or if the stream + * has already been marked as prepared + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markReadyToSendFirstFrame(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & READY_TO_SEND_FIRST_FRAME_FLAG) == READY_TO_SEND_FIRST_FRAME_FLAG) { + return state; + } + + if (updater.compareAndSet(instance, state, state | READY_TO_SEND_FIRST_FRAME_FLAG)) { + return state; + } + } + } + + /** + * Indicates that the logical stream is ready for initial frame sending + * + * @param state to check whether stream is prepared for initial frame sending + * @return true if the {@link #READY_TO_SEND_FIRST_FRAME_FLAG} flag is set + */ + static boolean isReadyToSendFirstFrame(long state) { + return (state & READY_TO_SEND_FIRST_FRAME_FLAG) == READY_TO_SEND_FIRST_FRAME_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #FIRST_PAYLOAD_RECEIVED_FLAG} flag which + * indicates that the logical stream is ready for initial frame sending. + * + *

Note, the flag will not be added if the stream has already been terminated or if the stream + * has already been marked as prepared + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markFirstPayloadReceived(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & FIRST_PAYLOAD_RECEIVED_FLAG) == FIRST_PAYLOAD_RECEIVED_FLAG) { + return state; + } + + if (updater.compareAndSet(instance, state, state | FIRST_PAYLOAD_RECEIVED_FLAG)) { + return state; + } + } + } + + /** + * Indicates that the logical stream is ready for initial frame sending + * + * @param state to check whether stream is established + * @return true if the {@link #FIRST_PAYLOAD_RECEIVED_FLAG} flag is set + */ + static boolean isFirstPayloadReceived(long state) { + return (state & FIRST_PAYLOAD_RECEIVED_FLAG) == FIRST_PAYLOAD_RECEIVED_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #REASSEMBLING_FLAG} flag which indicates that + * there is a payload reassembling in progress. + * + *

Note, the flag will not be added if the stream has already been terminated + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markReassembling(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if (updater.compareAndSet(instance, state, state | REASSEMBLING_FLAG)) { + return state; + } + } + } + + /** + * Removes (if possible) from the given state the {@link #REASSEMBLING_FLAG} flag which indicates + * that a payload reassembly process is completed. + * + *

Note, the flag will not be removed if the stream has already been terminated + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markReassembled(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if (updater.compareAndSet(instance, state, state & ~REASSEMBLING_FLAG)) { + return state; + } + } + } + + /** + * Indicates that a payload reassembly process is completed. + * + * @param state to check whether there is reassembly in progress + * @return true if the {@link #REASSEMBLING_FLAG} flag is set + */ + static boolean isReassembling(long state) { + return (state & REASSEMBLING_FLAG) == REASSEMBLING_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #INBOUND_TERMINATED_FLAG} flag which indicates + * that an inbound channel of a bidirectional stream is terminated. + * + *

Note, this action will have no effect if the stream has already been terminated or if + * the {@link #INBOUND_TERMINATED_FLAG} flag has already been set.
+ * Note, if the outbound stream has already been terminated, then the result state will be + * {@link #TERMINATED_STATE} + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markInboundTerminated(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & INBOUND_TERMINATED_FLAG) == INBOUND_TERMINATED_FLAG) { + return state; + } + + if ((state & OUTBOUND_TERMINATED_FLAG) == OUTBOUND_TERMINATED_FLAG) { + if (updater.compareAndSet(instance, state, TERMINATED_STATE)) { + return state; + } + } else { + if (updater.compareAndSet(instance, state, state | INBOUND_TERMINATED_FLAG)) { + return state; + } + } + } + } + + /** + * Indicates that a the inbound channel of a bidirectional stream is terminated. + * + * @param state to check whether it has {@link #INBOUND_TERMINATED_FLAG} set + * @return true if the {@link #INBOUND_TERMINATED_FLAG} flag is set + */ + static boolean isInboundTerminated(long state) { + return (state & INBOUND_TERMINATED_FLAG) == INBOUND_TERMINATED_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #OUTBOUND_TERMINATED_FLAG} flag which + * indicates that an outbound channel of a bidirectional stream is terminated. + * + *

Note, this action will have no effect if the stream has already been terminated or if + * the {@link #OUTBOUND_TERMINATED_FLAG} flag has already been set.
+ * Note, if the {@code checkEstablishment} parameter is {@code true} and the logical stream + * is not established, then the result state will be {@link #TERMINATED_STATE}
+ * Note, if the inbound stream has already been terminated, then the result state will be + * {@link #TERMINATED_STATE} + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param checkEstablishment indicates whether {@link #FIRST_FRAME_SENT_FLAG} should be checked to + * make final decision + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markOutboundTerminated( + AtomicLongFieldUpdater updater, T instance, boolean checkEstablishment) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & OUTBOUND_TERMINATED_FLAG) == OUTBOUND_TERMINATED_FLAG) { + return state; + } + + if ((checkEstablishment && !isFirstFrameSent(state)) + || (state & INBOUND_TERMINATED_FLAG) == INBOUND_TERMINATED_FLAG) { + if (updater.compareAndSet(instance, state, TERMINATED_STATE)) { + return state; + } + } else { + if (updater.compareAndSet(instance, state, state | OUTBOUND_TERMINATED_FLAG)) { + return state; + } + } + } + } + + /** + * Indicates that a the outbound channel of a bidirectional stream is terminated. + * + * @param state to check whether it has {@link #OUTBOUND_TERMINATED_FLAG} set + * @return true if the {@link #OUTBOUND_TERMINATED_FLAG} flag is set + */ + static boolean isOutboundTerminated(long state) { + return (state & OUTBOUND_TERMINATED_FLAG) == OUTBOUND_TERMINATED_FLAG; + } + + /** + * Makes current state a {@link #TERMINATED_STATE} + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markTerminated(AtomicLongFieldUpdater updater, T instance) { + return updater.getAndSet(instance, TERMINATED_STATE); + } + + /** + * Makes current state a {@link #TERMINATED_STATE} using {@link + * AtomicLongFieldUpdater#lazySet(Object, long)} + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + */ + static void lazyTerminate(AtomicLongFieldUpdater updater, T instance) { + updater.lazySet(instance, TERMINATED_STATE); + } + + /** + * Indicates that a the outbound channel of a bidirectional stream is terminated. + * + * @param state to check whether it has {@link #OUTBOUND_TERMINATED_FLAG} set + * @return true if the {@link #OUTBOUND_TERMINATED_FLAG} flag is set + */ + static boolean isTerminated(long state) { + return state == TERMINATED_STATE; + } + + /** + * Shortcut for {@link #isSubscribed} {@code ||} {@link #isTerminated} methods + * + * @param state to check flags on + * @return true if state is terminated or has flag subscribed + */ + static boolean isSubscribedOrTerminated(long state) { + return state == TERMINATED_STATE || (state & SUBSCRIBED_FLAG) == SUBSCRIBED_FLAG; + } + + static long addRequestN(AtomicLongFieldUpdater updater, T instance, long toAdd) { + return addRequestN(updater, instance, toAdd, false); + } + + static long addRequestN( + AtomicLongFieldUpdater updater, T instance, long toAdd, boolean markPrepared) { + long currentState, flags, requestN, nextRequestN; + for (; ; ) { + currentState = updater.get(instance); + + if (currentState == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + requestN = currentState & REQUEST_MASK; + if (requestN == REQUEST_MASK) { + return currentState; + } + + flags = (currentState & FLAGS_MASK) | (markPrepared ? READY_TO_SEND_FIRST_FRAME_FLAG : 0); + nextRequestN = addRequestN(requestN, toAdd); + + if (updater.compareAndSet(instance, currentState, nextRequestN | flags)) { + return currentState; + } + } + } + + static long addRequestN(long a, long b) { + long res = a + b; + if (res < 0 || res > REQUEST_MASK) { + return REQUEST_MASK; + } + return res; + } + + static boolean hasRequested(long state) { + return (state & REQUEST_MASK) > 0; + } + + static long extractRequestN(long state) { + long requestN = state & REQUEST_MASK; + + if (requestN == REQUEST_MASK) { + return REQUEST_MASK; + } + + return requestN; + } + + static boolean isMaxAllowedRequestN(long n) { + return n >= REQUEST_MASK; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/StreamIdSupplier.java b/rsocket-core/src/main/java/io/rsocket/core/StreamIdSupplier.java new file mode 100644 index 000000000..15d39c993 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/StreamIdSupplier.java @@ -0,0 +1,58 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.util.collection.IntObjectMap; + +/** This API is not thread-safe and must be strictly used in serialized fashion */ +final class StreamIdSupplier { + private static final int MASK = 0x7FFFFFFF; + + private long streamId; + + // Visible for testing + StreamIdSupplier(int streamId) { + this.streamId = streamId; + } + + static StreamIdSupplier clientSupplier() { + return new StreamIdSupplier(-1); + } + + static StreamIdSupplier serverSupplier() { + return new StreamIdSupplier(0); + } + + /** + * This methods provides new stream id and ensures there is no intersections with already running + * streams. This methods is not thread-safe. + * + * @param streamIds currently running streams store + * @return next stream id + */ + int nextStreamId(IntObjectMap streamIds) { + int streamId; + do { + this.streamId += 2; + streamId = (int) (this.streamId & MASK); + } while (streamId == 0 || streamIds.containsKey(streamId)); + return streamId; + } + + boolean isBeforeOrCurrent(int streamId) { + return this.streamId >= streamId && streamId > 0; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/package-info.java b/rsocket-core/src/main/java/io/rsocket/core/package-info.java new file mode 100644 index 000000000..29db3f205 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/package-info.java @@ -0,0 +1,28 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Contains {@link io.rsocket.core.RSocketConnector RSocketConnector} and {@link + * io.rsocket.core.RSocketServer RSocketServer}, the main classes for connecting to or starting an + * RSocket server. + * + *

This package also contains a package private classes that implement support for the main + * RSocket interactions. + */ +@NonNullApi +package io.rsocket.core; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationErrorException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationErrorException.java new file mode 100644 index 000000000..40cb15dd6 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationErrorException.java @@ -0,0 +1,51 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; + +/** + * Application layer logic generating a Reactive Streams {@code onError} event. + * + * @see Error + * Codes + */ +public final class ApplicationErrorException extends RSocketErrorException { + + private static final long serialVersionUID = 7873267740343446585L; + + /** + * Constructs a new exception with the specified message. + * + * @param message the message + */ + public ApplicationErrorException(String message) { + this(message, null); + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param message the message + * @param cause the cause of this exception + */ + public ApplicationErrorException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.APPLICATION_ERROR, message, cause); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/CanceledException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/CanceledException.java new file mode 100644 index 000000000..144ef94c6 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/CanceledException.java @@ -0,0 +1,52 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; + +/** + * The Responder canceled the request but may have started processing it (similar to REJECTED but + * doesn't guarantee lack of side-effects). + * + * @see Error + * Codes + */ +public final class CanceledException extends RSocketErrorException { + + private static final long serialVersionUID = 5074789326089722770L; + + /** + * Constructs a new exception with the specified message. + * + * @param message the message + */ + public CanceledException(String message) { + this(message, null); + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param message the message + * @param cause the cause of this exception + */ + public CanceledException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.CANCELED, message, cause); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java index 9ba4f781c..1e0167bdd 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java @@ -1,21 +1,52 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.rsocket.exceptions; -import io.rsocket.frame.ErrorFrameFlyweight; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; -public class ConnectionCloseException extends RSocketException { +/** + * The connection is being terminated. Sender or Receiver of this frame MUST wait for outstanding + * streams to terminate before closing the connection. New requests MAY not be accepted. + * + * @see Error + * Codes + */ +public final class ConnectionCloseException extends RSocketErrorException { - private static final long serialVersionUID = -7659717517940756969L; + private static final long serialVersionUID = -2214953527482377471L; + /** + * Constructs a new exception with the specified message. + * + * @param message the message + */ public ConnectionCloseException(String message) { - super(message); - } - - public ConnectionCloseException(String message, Throwable cause) { - super(message, cause); + this(message, null); } - @Override - public int errorCode() { - return ErrorFrameFlyweight.CONNECTION_CLOSE; + /** + * Constructs a new exception with the specified message and cause. + * + * @param message the message + * @param cause the cause of this exception + */ + public ConnectionCloseException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.CONNECTION_CLOSE, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionErrorException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionErrorException.java new file mode 100644 index 000000000..5cf7cff66 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionErrorException.java @@ -0,0 +1,52 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; + +/** + * The connection is being terminated. Sender or Receiver of this frame MAY close the connection + * immediately without waiting for outstanding streams to terminate. + * + * @see Error + * Codes + */ +public final class ConnectionErrorException extends RSocketErrorException implements Retryable { + + private static final long serialVersionUID = 512325887785119744L; + + /** + * Constructs a new exception with the specified message. + * + * @param message the message + */ + public ConnectionErrorException(String message) { + this(message, null); + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param message the message + * @param cause the cause of this exception + */ + public ConnectionErrorException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.CONNECTION_ERROR, message, cause); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java new file mode 100644 index 000000000..a72c0ba3b --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java @@ -0,0 +1,53 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; + +public class CustomRSocketException extends RSocketErrorException { + private static final long serialVersionUID = 7873267740343446585L; + + /** + * Constructs a new exception with the specified message. + * + * @param errorCode customizable error code. Should be in range [0x00000301-0xFFFFFFFE] + * @param message the message + * @throws IllegalArgumentException if {@code errorCode} is out of allowed range + */ + public CustomRSocketException(int errorCode, String message) { + this(errorCode, message, null); + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param errorCode customizable error code. Should be in range [0x00000301-0xFFFFFFFE] + * @param message the message + * @param cause the cause of this exception + * @throws IllegalArgumentException if {@code errorCode} is out of allowed range + */ + public CustomRSocketException(int errorCode, String message, @Nullable Throwable cause) { + super(errorCode, message, cause); + if (errorCode > ErrorFrameCodec.MAX_USER_ALLOWED_ERROR_CODE + && errorCode < ErrorFrameCodec.MIN_USER_ALLOWED_ERROR_CODE) { + throw new IllegalArgumentException( + "Allowed errorCode value should be in range [0x00000301-0xFFFFFFFE]", this); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java b/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java index 1e1c9bcf0..5c6eee614 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,44 +13,83 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package io.rsocket.exceptions; -import static io.rsocket.frame.ErrorFrameFlyweight.*; +import static io.rsocket.frame.ErrorFrameCodec.APPLICATION_ERROR; +import static io.rsocket.frame.ErrorFrameCodec.CANCELED; +import static io.rsocket.frame.ErrorFrameCodec.CONNECTION_CLOSE; +import static io.rsocket.frame.ErrorFrameCodec.CONNECTION_ERROR; +import static io.rsocket.frame.ErrorFrameCodec.INVALID; +import static io.rsocket.frame.ErrorFrameCodec.INVALID_SETUP; +import static io.rsocket.frame.ErrorFrameCodec.MAX_USER_ALLOWED_ERROR_CODE; +import static io.rsocket.frame.ErrorFrameCodec.MIN_USER_ALLOWED_ERROR_CODE; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED_RESUME; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED_SETUP; +import static io.rsocket.frame.ErrorFrameCodec.UNSUPPORTED_SETUP; -import io.rsocket.Frame; +import io.netty.buffer.ByteBuf; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import java.util.Objects; -public class Exceptions { +/** Utility class that generates an exception from a frame. */ +public final class Exceptions { private Exceptions() {} - public static RuntimeException from(Frame frame) { - final int errorCode = Frame.Error.errorCode(frame); - - String message = frame.getDataUtf8(); - switch (errorCode) { - case APPLICATION_ERROR: - return new ApplicationException(message); - case CANCELED: - return new CancelException(message); - case CONNECTION_CLOSE: - return new ConnectionCloseException(message); - case CONNECTION_ERROR: - return new ConnectionException(message); - case INVALID: - return new InvalidRequestException(message); - case INVALID_SETUP: - return new InvalidSetupException(message); - case REJECTED: - return new RejectedException(message); - case REJECTED_RESUME: - return new RejectedResumeException(message); - case REJECTED_SETUP: - return new RejectedSetupException(message); - case UNSUPPORTED_SETUP: - return new UnsupportedSetupException(message); - default: - return new InvalidRequestException( - "Invalid Error frame: " + errorCode + " '" + message + "'"); + /** + * Create a {@link RSocketErrorException} from a Frame that matches the error code it contains. + * + * @param frame the frame to retrieve the error code and message from + * @return a {@link RSocketErrorException} that matches the error code in the Frame + * @throws NullPointerException if {@code frame} is {@code null} + */ + public static RuntimeException from(int streamId, ByteBuf frame) { + Objects.requireNonNull(frame, "frame must not be null"); + + int errorCode = ErrorFrameCodec.errorCode(frame); + String message = ErrorFrameCodec.dataUtf8(frame); + + if (streamId == 0) { + switch (errorCode) { + case INVALID_SETUP: + return new InvalidSetupException(message); + case UNSUPPORTED_SETUP: + return new UnsupportedSetupException(message); + case REJECTED_SETUP: + return new RejectedSetupException(message); + case REJECTED_RESUME: + return new RejectedResumeException(message); + case CONNECTION_ERROR: + return new ConnectionErrorException(message); + case CONNECTION_CLOSE: + return new ConnectionCloseException(message); + default: + return new IllegalArgumentException( + String.format("Invalid Error frame in Stream ID 0: 0x%08X '%s'", errorCode, message)); + } + } else { + switch (errorCode) { + case APPLICATION_ERROR: + return new ApplicationErrorException(message); + case REJECTED: + return new RejectedException(message); + case CANCELED: + return new CanceledException(message); + case INVALID: + return new InvalidException(message); + default: + if (errorCode >= MIN_USER_ALLOWED_ERROR_CODE + || errorCode <= MAX_USER_ALLOWED_ERROR_CODE) { + return new CustomRSocketException(errorCode, message); + } + return new IllegalArgumentException( + String.format( + "Invalid Error frame in Stream ID %d: 0x%08X '%s'", + streamId, errorCode, message)); + } } } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidException.java new file mode 100644 index 000000000..c556423b9 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidException.java @@ -0,0 +1,51 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; + +/** + * The request is invalid. + * + * @see Error + * Codes + */ +public final class InvalidException extends RSocketErrorException { + + private static final long serialVersionUID = 8279420324864928243L; + + /** + * Constructs a new exception with the specified message. + * + * @param message the message + */ + public InvalidException(String message) { + this(message, null); + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param message the message + * @param cause the cause of this exception + */ + public InvalidException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.INVALID, message, cause); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidSetupException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidSetupException.java index 874247398..b0889c5a6 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidSetupException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidSetupException.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,24 +13,39 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package io.rsocket.exceptions; -import io.rsocket.frame.ErrorFrameFlyweight; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; -public class InvalidSetupException extends SetupException { +/** + * The Setup frame is invalid for the server (it could be that the client is too recent for the old + * server). + * + * @see Error + * Codes + */ +public final class InvalidSetupException extends SetupException { - private static final long serialVersionUID = -6685677299580579050L; + private static final long serialVersionUID = -6816210006610385251L; + /** + * Constructs a new exception with the specified message. + * + * @param message the message + */ public InvalidSetupException(String message) { - super(message); - } - - public InvalidSetupException(String message, Throwable cause) { - super(message, cause); + this(message, null); } - @Override - public int errorCode() { - return ErrorFrameFlyweight.INVALID_SETUP; + /** + * Constructs a new exception with the specified message and cause. + * + * @param message the message + * @param cause the cause of this exception + */ + public InvalidSetupException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.INVALID_SETUP, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/RSocketException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/RSocketException.java deleted file mode 100644 index 869b80acd..000000000 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/RSocketException.java +++ /dev/null @@ -1,16 +0,0 @@ -package io.rsocket.exceptions; - -public abstract class RSocketException extends RuntimeException { - - private static final long serialVersionUID = 2912815394105575423L; - - public RSocketException(String message) { - super(message); - } - - public RSocketException(String message, Throwable cause) { - super(message, cause); - } - - public abstract int errorCode(); -} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java index c10411f8e..8bc946e3d 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,24 +13,41 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package io.rsocket.exceptions; -import io.rsocket.frame.ErrorFrameFlyweight; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; -public class RejectedException extends RSocketException implements Retryable { +/** + * Despite being a valid request, the Responder decided to reject it. The Responder guarantees that + * it didn't process the request. The reason for the rejection is explained in the Error Data + * section. + * + * @see Error + * Codes + */ +public class RejectedException extends RSocketErrorException implements Retryable { - private static final long serialVersionUID = 2773784636669279750L; + private static final long serialVersionUID = 3926231092835143715L; + /** + * Constructs a new exception with the specified message. + * + * @param message the message + */ public RejectedException(String message) { - super(message); - } - - public RejectedException(String message, Throwable cause) { - super(message, cause); + this(message, null); } - @Override - public int errorCode() { - return ErrorFrameFlyweight.REJECTED; + /** + * Constructs a new exception with the specified message and cause. + * + * @param message the message + * @param cause the cause of this exception + */ + public RejectedException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.REJECTED, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java index 50a68bacd..44cc55710 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java @@ -1,21 +1,51 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.rsocket.exceptions; -import io.rsocket.frame.ErrorFrameFlyweight; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; -public class RejectedResumeException extends RSocketException { +/** + * The server rejected the resume, it can specify the reason in the payload. + * + * @see Error + * Codes + */ +public final class RejectedResumeException extends RSocketErrorException { - private static final long serialVersionUID = 6953301234450438491L; + private static final long serialVersionUID = -873684362478544811L; + /** + * Constructs a new exception with the specified message. + * + * @param message the message + */ public RejectedResumeException(String message) { - super(message); - } - - public RejectedResumeException(String message, Throwable cause) { - super(message, cause); + this(message, null); } - @Override - public int errorCode() { - return ErrorFrameFlyweight.REJECTED_RESUME; + /** + * Constructs a new exception with the specified message and cause. + * + * @param message the message + * @param cause the cause of this exception + */ + public RejectedResumeException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.REJECTED_RESUME, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedSetupException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedSetupException.java index 76da843d4..c09a27e32 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedSetupException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedSetupException.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,24 +13,38 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package io.rsocket.exceptions; -import io.rsocket.frame.ErrorFrameFlyweight; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; -public class RejectedSetupException extends SetupException implements Retryable { +/** + * The server rejected the setup, it can specify the reason in the payload. + * + * @see Error + * Codes + */ +public final class RejectedSetupException extends SetupException implements Retryable { - private static final long serialVersionUID = -4932830657505898008L; + private static final long serialVersionUID = 8757401529926371738L; + /** + * Constructs a new exception with the specified message. + * + * @param message the message + */ public RejectedSetupException(String message) { - super(message); - } - - public RejectedSetupException(String message, Throwable cause) { - super(message, cause); + this(message, null); } - @Override - public int errorCode() { - return ErrorFrameFlyweight.REJECTED_SETUP; + /** + * Constructs a new exception with the specified message and cause. + * + * @param message the message + * @param cause the cause of this exception + */ + public RejectedSetupException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.REJECTED_SETUP, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/Retryable.java b/rsocket-core/src/main/java/io/rsocket/exceptions/Retryable.java index 86611bd61..e61fe4f97 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/Retryable.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/Retryable.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,7 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package io.rsocket.exceptions; -/** Marker interface only */ +/** + * Indicates that an exception is retryable. This interface is a marker and the strategy for + * retrying and operation that causes a {@link Retryable} to be thrown is not specified. + */ public interface Retryable {} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java index 8c7c03532..76dc39a59 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package io.rsocket.exceptions; -public abstract class SetupException extends RSocketException { +import io.rsocket.RSocketErrorException; +import reactor.util.annotation.Nullable; - private static final long serialVersionUID = -2928269501877732756L; +/** The root of the setup exception hierarchy. */ +public abstract class SetupException extends RSocketErrorException { - public SetupException(String message) { - super(message); - } + private static final long serialVersionUID = -2928269501877732756L; - public SetupException(String message, Throwable cause) { - super(message, cause); + /** + * Constructs a new exception with the specified error code, message and cause. + * + * @param errorCode the RSocket protocol code + * @param message the message + * @param cause the cause of this exception + */ + public SetupException(int errorCode, String message, @Nullable Throwable cause) { + super(errorCode, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/UnsupportedSetupException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/UnsupportedSetupException.java index d5224b64a..7429ccd98 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/UnsupportedSetupException.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/UnsupportedSetupException.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,24 +13,38 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package io.rsocket.exceptions; -import io.rsocket.frame.ErrorFrameFlyweight; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; -public class UnsupportedSetupException extends SetupException { +/** + * Some (or all) of the parameters specified by the client are unsupported by the server. + * + * @see Error + * Codes + */ +public final class UnsupportedSetupException extends SetupException { - private static final long serialVersionUID = -2533421488941132736L; + private static final long serialVersionUID = -1892507835635323415L; + /** + * Constructs a new exception with the specified message. + * + * @param message the message + */ public UnsupportedSetupException(String message) { - super(message); - } - - public UnsupportedSetupException(String message, Throwable cause) { - super(message, cause); + this(message, null); } - @Override - public int errorCode() { - return ErrorFrameFlyweight.UNSUPPORTED_SETUP; + /** + * Constructs a new exception with the specified message and cause. + * + * @param message the message + * @param cause the cause of this exception + */ + public UnsupportedSetupException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.UNSUPPORTED_SETUP, message, cause); } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/package-info.java b/rsocket-core/src/main/java/io/rsocket/exceptions/package-info.java index 5a29ce4ab..969aedded 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/package-info.java @@ -1,18 +1,26 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -@javax.annotation.ParametersAreNonnullByDefault +/** + * A hierarchy of exceptions that represent RSocket protocol error codes. + * + * @see Error + * Codes + */ +@NonNullApi package io.rsocket.exceptions; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java deleted file mode 100644 index 38c46b62c..000000000 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FragmentationDuplexConnection.java +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.fragmentation; - -import io.netty.util.collection.IntObjectHashMap; -import io.rsocket.DuplexConnection; -import io.rsocket.Frame; -import io.rsocket.frame.FrameHeaderFlyweight; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -/** Fragments and Re-assembles frames. MTU is number of bytes per fragment. The default is 1024 */ -public class FragmentationDuplexConnection implements DuplexConnection { - - private final DuplexConnection source; - private final IntObjectHashMap frameReassemblers = new IntObjectHashMap<>(); - private final FrameFragmenter frameFragmenter; - - public FragmentationDuplexConnection(DuplexConnection source, int mtu) { - this.source = source; - this.frameFragmenter = new FrameFragmenter(mtu); - } - - public static int getDefaultMTU() { - if (Boolean.getBoolean("io.rsocket.fragmentation.enable")) { - return Integer.getInteger("io.rsocket.fragmentation.mtu", 1024); - } - - return 0; - } - - @Override - public double availability() { - return source.availability(); - } - - @Override - public Mono send(Publisher frames) { - return Flux.from(frames).concatMap(this::sendOne).then(); - } - - @Override - public Mono sendOne(Frame frame) { - if (frameFragmenter.shouldFragment(frame)) { - return source.send(frameFragmenter.fragment(frame)); - } else { - return source.sendOne(frame); - } - } - - @Override - public Flux receive() { - return source - .receive() - .concatMap( - frame -> { - if (FrameHeaderFlyweight.FLAGS_F == (frame.flags() & FrameHeaderFlyweight.FLAGS_F)) { - FrameReassembler frameReassembler = getFrameReassembler(frame); - frameReassembler.append(frame); - return Mono.empty(); - } else if (frameReassemblersContain(frame.getStreamId())) { - FrameReassembler frameReassembler = removeFrameReassembler(frame.getStreamId()); - frameReassembler.append(frame); - Frame reassembled = frameReassembler.reassemble(); - return Mono.just(reassembled); - } else { - return Mono.just(frame); - } - }); - } - - @Override - public Mono close() { - return source.close(); - } - - private synchronized FrameReassembler getFrameReassembler(Frame frame) { - return frameReassemblers.computeIfAbsent(frame.getStreamId(), s -> new FrameReassembler(frame)); - } - - private synchronized FrameReassembler removeFrameReassembler(int streamId) { - return frameReassemblers.remove(streamId); - } - - private synchronized boolean frameReassemblersContain(int streamId) { - return frameReassemblers.containsKey(streamId); - } - - @Override - public Mono onClose() { - return source - .onClose() - .doFinally( - s -> { - synchronized (FragmentationDuplexConnection.this) { - frameReassemblers.values().forEach(FrameReassembler::dispose); - - frameReassemblers.clear(); - } - }); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java deleted file mode 100644 index 865f34f40..000000000 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.fragmentation; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.rsocket.Frame; -import io.rsocket.FrameType; -import io.rsocket.frame.FrameHeaderFlyweight; -import java.util.function.Consumer; -import javax.annotation.Nullable; -import reactor.core.publisher.Flux; -import reactor.core.publisher.SynchronousSink; - -public class FrameFragmenter { - private final int mtu; - - public FrameFragmenter(int mtu) { - this.mtu = mtu; - } - - public boolean shouldFragment(Frame frame) { - return isFragmentableFrame(frame.getType()) - && FrameHeaderFlyweight.payloadLength(frame.content()) > mtu; - } - - private boolean isFragmentableFrame(FrameType type) { - switch (type) { - case FIRE_AND_FORGET: - case REQUEST_STREAM: - case REQUEST_CHANNEL: - case REQUEST_RESPONSE: - case PAYLOAD: - case NEXT_COMPLETE: - case METADATA_PUSH: - return true; - default: - return false; - } - } - - public Flux fragment(Frame frame) { - - return Flux.generate(new FragmentGenerator(frame)); - } - - private class FragmentGenerator implements Consumer> { - private final Frame frame; - private final int streamId; - private final FrameType frameType; - private final int flags; - - private ByteBuf data; - private @Nullable ByteBuf metadata; - - public FragmentGenerator(Frame frame) { - this.frame = frame.retain(); - this.streamId = frame.getStreamId(); - this.frameType = frame.getType(); - this.flags = frame.flags() & ~FrameHeaderFlyweight.FLAGS_M; - metadata = - frame.hasMetadata() ? FrameHeaderFlyweight.sliceFrameMetadata(frame.content()) : null; - data = FrameHeaderFlyweight.sliceFrameData(frame.content()); - } - - @Override - public void accept(SynchronousSink sink) { - final int dataLength = data.readableBytes(); - - if (metadata != null) { - final int metadataLength = metadata.readableBytes(); - - if (metadataLength > mtu) { - sink.next( - Frame.PayloadFrame.from( - streamId, - frameType, - metadata.readSlice(mtu), - Unpooled.EMPTY_BUFFER, - flags | FrameHeaderFlyweight.FLAGS_M | FrameHeaderFlyweight.FLAGS_F)); - } else { - if (dataLength > mtu - metadataLength) { - sink.next( - Frame.PayloadFrame.from( - streamId, - frameType, - metadata.readSlice(metadataLength), - data.readSlice(mtu - metadataLength), - flags | FrameHeaderFlyweight.FLAGS_M | FrameHeaderFlyweight.FLAGS_F)); - } else { - sink.next( - Frame.PayloadFrame.from( - streamId, - frameType, - metadata.readSlice(metadataLength), - data.readSlice(dataLength), - flags | FrameHeaderFlyweight.FLAGS_M)); - frame.release(); - sink.complete(); - } - } - } else { - if (dataLength > mtu) { - sink.next( - Frame.PayloadFrame.from( - streamId, - frameType, - Unpooled.EMPTY_BUFFER, - data.readSlice(mtu), - flags | FrameHeaderFlyweight.FLAGS_F)); - } else { - sink.next( - Frame.PayloadFrame.from( - streamId, frameType, Unpooled.EMPTY_BUFFER, data.readSlice(dataLength), flags)); - frame.release(); - sink.complete(); - } - } - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java deleted file mode 100644 index 57ebd3ebe..000000000 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/FrameReassembler.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.fragmentation; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.CompositeByteBuf; -import io.netty.buffer.PooledByteBufAllocator; -import io.rsocket.Frame; -import io.rsocket.FrameType; -import io.rsocket.frame.FrameHeaderFlyweight; -import reactor.core.Disposable; - -/** Assembles Fragmented frames. */ -public class FrameReassembler implements Disposable { - private final FrameType frameType; - private final int streamId; - private final int flags; - private final CompositeByteBuf dataBuffer; - private final CompositeByteBuf metadataBuffer; - - public FrameReassembler(Frame frame) { - this.frameType = frame.getType(); - this.streamId = frame.getStreamId(); - this.flags = frame.flags(); - dataBuffer = PooledByteBufAllocator.DEFAULT.compositeBuffer(); - metadataBuffer = PooledByteBufAllocator.DEFAULT.compositeBuffer(); - } - - public synchronized void append(Frame frame) { - final ByteBuf byteBuf = frame.content(); - final FrameType frameType = FrameHeaderFlyweight.frameType(byteBuf); - final int frameLength = FrameHeaderFlyweight.frameLength(byteBuf); - final int metadataLength = FrameHeaderFlyweight.metadataLength(byteBuf, frameType, frameLength); - final int dataLength = FrameHeaderFlyweight.dataLength(byteBuf, frameType); - if (0 < metadataLength) { - int metadataOffset = FrameHeaderFlyweight.metadataOffset(byteBuf); - if (FrameHeaderFlyweight.hasMetadataLengthField(frameType)) { - metadataOffset += FrameHeaderFlyweight.FRAME_LENGTH_SIZE; - } - metadataBuffer.addComponent(true, byteBuf.retainedSlice(metadataOffset, metadataLength)); - } - if (0 < dataLength) { - final int dataOffset = FrameHeaderFlyweight.dataOffset(byteBuf, frameType, frameLength); - dataBuffer.addComponent(true, byteBuf.retainedSlice(dataOffset, dataLength)); - } - } - - public synchronized Frame reassemble() { - return Frame.PayloadFrame.from(streamId, frameType, metadataBuffer, dataBuffer, flags); - } - - @Override - public void dispose() { - dataBuffer.release(); - metadataBuffer.release(); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/fragmentation/package-info.java b/rsocket-core/src/main/java/io/rsocket/fragmentation/package-info.java deleted file mode 100644 index d4b5244b1..000000000 --- a/rsocket-core/src/main/java/io/rsocket/fragmentation/package-info.java +++ /dev/null @@ -1,18 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -@javax.annotation.ParametersAreNonnullByDefault -package io.rsocket.fragmentation; diff --git a/rsocket-core/src/main/java/io/rsocket/frame/CancelFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/CancelFrameCodec.java new file mode 100644 index 000000000..d0d929f0f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/CancelFrameCodec.java @@ -0,0 +1,12 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; + +public class CancelFrameCodec { + private CancelFrameCodec() {} + + public static ByteBuf encode(final ByteBufAllocator allocator, final int streamId) { + return FrameHeaderCodec.encode(allocator, streamId, FrameType.CANCEL, 0); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameCodec.java new file mode 100644 index 000000000..dcacb57dc --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameCodec.java @@ -0,0 +1,66 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.rsocket.RSocketErrorException; +import java.nio.charset.StandardCharsets; + +public class ErrorFrameCodec { + + // defined zero stream id error codes + public static final int INVALID_SETUP = 0x00000001; + public static final int UNSUPPORTED_SETUP = 0x00000002; + public static final int REJECTED_SETUP = 0x00000003; + public static final int REJECTED_RESUME = 0x00000004; + public static final int CONNECTION_ERROR = 0x00000101; + public static final int CONNECTION_CLOSE = 0x00000102; + // defined non-zero stream id error codes + public static final int APPLICATION_ERROR = 0x00000201; + public static final int REJECTED = 0x00000202; + public static final int CANCELED = 0x00000203; + public static final int INVALID = 0x00000204; + // defined user-allowed error codes range + public static final int MIN_USER_ALLOWED_ERROR_CODE = 0x00000301; + public static final int MAX_USER_ALLOWED_ERROR_CODE = 0xFFFFFFFE; + + public static ByteBuf encode( + ByteBufAllocator allocator, int streamId, Throwable t, ByteBuf data) { + ByteBuf header = FrameHeaderCodec.encode(allocator, streamId, FrameType.ERROR, 0); + + int errorCode = + t instanceof RSocketErrorException + ? ((RSocketErrorException) t).errorCode() + : APPLICATION_ERROR; + + header.writeInt(errorCode); + + return allocator.compositeBuffer(2).addComponents(true, header, data); + } + + public static ByteBuf encode(ByteBufAllocator allocator, int streamId, Throwable t) { + String message = t.getMessage() == null ? "" : t.getMessage(); + ByteBuf data = ByteBufUtil.writeUtf8(allocator, message); + return encode(allocator, streamId, t, data); + } + + public static int errorCode(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + int i = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + return i; + } + + public static ByteBuf data(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); + ByteBuf slice = byteBuf.slice(); + byteBuf.resetReaderIndex(); + return slice; + } + + public static String dataUtf8(ByteBuf byteBuf) { + return data(byteBuf).toString(StandardCharsets.UTF_8); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameFlyweight.java deleted file mode 100644 index d36438e0f..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameFlyweight.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.rsocket.FrameType; -import io.rsocket.exceptions.*; -import java.nio.charset.StandardCharsets; - -public class ErrorFrameFlyweight { - - private ErrorFrameFlyweight() {} - - // defined error codes - public static final int INVALID_SETUP = 0x00000001; - public static final int UNSUPPORTED_SETUP = 0x00000002; - public static final int REJECTED_SETUP = 0x00000003; - public static final int REJECTED_RESUME = 0x00000004; - public static final int CONNECTION_ERROR = 0x00000101; - public static final int CONNECTION_CLOSE = 0x00000102; - public static final int APPLICATION_ERROR = 0x00000201; - public static final int REJECTED = 0x00000202; - public static final int CANCELED = 0x00000203; - public static final int INVALID = 0x00000204; - - // relative to start of passed offset - private static final int ERROR_CODE_FIELD_OFFSET = FrameHeaderFlyweight.FRAME_HEADER_LENGTH; - private static final int PAYLOAD_OFFSET = ERROR_CODE_FIELD_OFFSET + Integer.BYTES; - - public static int computeFrameLength(final int dataLength) { - int length = FrameHeaderFlyweight.computeFrameHeaderLength(FrameType.ERROR, null, dataLength); - return length + Integer.BYTES; - } - - public static int encode( - final ByteBuf byteBuf, final int streamId, final int errorCode, final ByteBuf data) { - final int frameLength = computeFrameLength(data.readableBytes()); - - int length = - FrameHeaderFlyweight.encodeFrameHeader(byteBuf, frameLength, 0, FrameType.ERROR, streamId); - - byteBuf.setInt(ERROR_CODE_FIELD_OFFSET, errorCode); - length += Integer.BYTES; - - length += FrameHeaderFlyweight.encodeData(byteBuf, length, data); - - return length; - } - - public static int errorCodeFromException(Throwable ex) { - if (ex instanceof RSocketException) { - return ((RSocketException) ex).errorCode(); - } - - return APPLICATION_ERROR; - } - - public static int errorCode(final ByteBuf byteBuf) { - return byteBuf.getInt(ERROR_CODE_FIELD_OFFSET); - } - - public static int payloadOffset(final ByteBuf byteBuf) { - return FrameHeaderFlyweight.FRAME_HEADER_LENGTH + Integer.BYTES; - } - - public static String message(ByteBuf content) { - return FrameHeaderFlyweight.sliceFrameData(content).toString(StandardCharsets.UTF_8); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ExtensionFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/ExtensionFrameCodec.java new file mode 100644 index 000000000..418926596 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/ExtensionFrameCodec.java @@ -0,0 +1,67 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import reactor.util.annotation.Nullable; + +public class ExtensionFrameCodec { + private ExtensionFrameCodec() {} + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + int extendedType, + @Nullable ByteBuf metadata, + ByteBuf data) { + + final boolean hasMetadata = metadata != null; + + int flags = FrameHeaderCodec.FLAGS_I; + + if (hasMetadata) { + flags |= FrameHeaderCodec.FLAGS_M; + } + + final ByteBuf header = FrameHeaderCodec.encode(allocator, streamId, FrameType.EXT, flags); + header.writeInt(extendedType); + + return FrameBodyCodec.encode(allocator, header, metadata, hasMetadata, data); + } + + public static int extendedType(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.EXT, byteBuf); + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + int i = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + return i; + } + + public static ByteBuf data(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.EXT, byteBuf); + + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + byteBuf.markReaderIndex(); + // Extended type + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); + ByteBuf data = FrameBodyCodec.dataWithoutMarking(byteBuf, hasMetadata); + byteBuf.resetReaderIndex(); + return data; + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.EXT, byteBuf); + + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + if (!hasMetadata) { + return null; + } + byteBuf.markReaderIndex(); + // Extended type + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); + ByteBuf metadata = FrameBodyCodec.metadataWithoutMarking(byteBuf); + byteBuf.resetReaderIndex(); + return metadata; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FragmentationCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/FragmentationCodec.java new file mode 100644 index 000000000..de228b271 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/FragmentationCodec.java @@ -0,0 +1,19 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import reactor.util.annotation.Nullable; + +/** FragmentationFlyweight is used to re-assemble frames */ +public class FragmentationCodec { + public static ByteBuf encode(final ByteBufAllocator allocator, ByteBuf header, ByteBuf data) { + return encode(allocator, header, null, data); + } + + public static ByteBuf encode( + final ByteBufAllocator allocator, ByteBuf header, @Nullable ByteBuf metadata, ByteBuf data) { + + final boolean hasMetadata = metadata != null; + return FrameBodyCodec.encode(allocator, header, metadata, hasMetadata, data); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameBodyCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameBodyCodec.java new file mode 100644 index 000000000..ea011e503 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameBodyCodec.java @@ -0,0 +1,103 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import reactor.util.annotation.Nullable; + +class FrameBodyCodec { + public static final int FRAME_LENGTH_MASK = 0xFFFFFF; + + private FrameBodyCodec() {} + + private static void encodeLength(final ByteBuf byteBuf, final int length) { + if ((length & ~FRAME_LENGTH_MASK) != 0) { + throw new IllegalArgumentException("Length is larger than 24 bits"); + } + // Write each byte separately in reverse order, this mean we can write 1 << 23 without + // overflowing. + byteBuf.writeByte(length >> 16); + byteBuf.writeByte(length >> 8); + byteBuf.writeByte(length); + } + + private static int decodeLength(final ByteBuf byteBuf) { + byte b = byteBuf.readByte(); + int length = (b & 0xFF) << 16; + byte b1 = byteBuf.readByte(); + length |= (b1 & 0xFF) << 8; + byte b2 = byteBuf.readByte(); + length |= b2 & 0xFF; + return length; + } + + static ByteBuf encode( + ByteBufAllocator allocator, + final ByteBuf header, + @Nullable ByteBuf metadata, + boolean hasMetadata, + @Nullable ByteBuf data) { + + final boolean addData; + if (data != null) { + if (data.isReadable()) { + addData = true; + } else { + // even though there is nothing to read, we still have to release here since nobody else + // going to do soo + data.release(); + addData = false; + } + } else { + addData = false; + } + + final boolean addMetadata; + if (hasMetadata) { + if (metadata.isReadable()) { + addMetadata = true; + } else { + // even though there is nothing to read, we still have to release here since nobody else + // going to do soo + metadata.release(); + addMetadata = false; + } + } else { + // has no metadata means it is null, thus no need to release anything + addMetadata = false; + } + + if (hasMetadata) { + int length = metadata.readableBytes(); + encodeLength(header, length); + } + + if (addMetadata && addData) { + return allocator.compositeBuffer(3).addComponents(true, header, metadata, data); + } else if (addMetadata) { + return allocator.compositeBuffer(2).addComponents(true, header, metadata); + } else if (addData) { + return allocator.compositeBuffer(2).addComponents(true, header, data); + } else { + return header; + } + } + + static ByteBuf metadataWithoutMarking(ByteBuf byteBuf) { + int length = decodeLength(byteBuf); + return byteBuf.readSlice(length); + } + + static ByteBuf dataWithoutMarking(ByteBuf byteBuf, boolean hasMetadata) { + if (hasMetadata) { + /*moves reader index*/ + int length = decodeLength(byteBuf); + byteBuf.skipBytes(length); + } + if (byteBuf.readableBytes() > 0) { + return byteBuf.readSlice(byteBuf.readableBytes()); + } else { + return Unpooled.EMPTY_BUFFER; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java new file mode 100644 index 000000000..fc146c935 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java @@ -0,0 +1,140 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import org.reactivestreams.Subscriber; + +/** + * Per connection frame flyweight. + * + *

Not the latest frame layout, but close. Does not include - fragmentation / reassembly - encode + * should remove Type param and have it as part of method name (1 encode per type?) + * + *

Not thread-safe. Assumed to be used single-threaded + */ +public final class FrameHeaderCodec { + /** (I)gnore flag: a value of 0 indicates the protocol can't ignore this frame */ + public static final int FLAGS_I = 0b10_0000_0000; + /** (M)etadata flag: a value of 1 indicates the frame contains metadata */ + public static final int FLAGS_M = 0b01_0000_0000; + /** + * (F)ollows: More fragments follow this fragment (in case of fragmented REQUEST_x or PAYLOAD + * frames) + */ + public static final int FLAGS_F = 0b00_1000_0000; + /** (C)omplete: bit to indicate stream completion ({@link Subscriber#onComplete()}) */ + public static final int FLAGS_C = 0b00_0100_0000; + /** (N)ext: bit to indicate payload or metadata present ({@link Subscriber#onNext(Object)}) */ + public static final int FLAGS_N = 0b00_0010_0000; + + public static final String DISABLE_FRAME_TYPE_CHECK = "io.rsocket.frames.disableFrameTypeCheck"; + private static final int FRAME_FLAGS_MASK = 0b0000_0011_1111_1111; + private static final int FRAME_TYPE_BITS = 6; + private static final int FRAME_TYPE_SHIFT = 16 - FRAME_TYPE_BITS; + private static final int HEADER_SIZE = Integer.BYTES + Short.BYTES; + private static boolean disableFrameTypeCheck; + + static { + disableFrameTypeCheck = Boolean.getBoolean(DISABLE_FRAME_TYPE_CHECK); + } + + private FrameHeaderCodec() {} + + static ByteBuf encodeStreamZero( + final ByteBufAllocator allocator, final FrameType frameType, int flags) { + return encode(allocator, 0, frameType, flags); + } + + public static ByteBuf encode( + final ByteBufAllocator allocator, final int streamId, final FrameType frameType, int flags) { + if (!frameType.canHaveMetadata() && ((flags & FLAGS_M) == FLAGS_M)) { + throw new IllegalStateException("bad value for metadata flag"); + } + + short typeAndFlags = (short) (frameType.getEncodedType() << FRAME_TYPE_SHIFT | (short) flags); + + return allocator.buffer().writeInt(streamId).writeShort(typeAndFlags); + } + + public static boolean hasFollows(ByteBuf byteBuf) { + return (flags(byteBuf) & FLAGS_F) == FLAGS_F; + } + + public static boolean hasComplete(ByteBuf byteBuf) { + return (flags(byteBuf) & FLAGS_C) == FLAGS_C; + } + + public static int streamId(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + int streamId = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + return streamId; + } + + public static int flags(final ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + byteBuf.skipBytes(Integer.BYTES); + short typeAndFlags = byteBuf.readShort(); + byteBuf.resetReaderIndex(); + return typeAndFlags & FRAME_FLAGS_MASK; + } + + public static boolean hasMetadata(ByteBuf byteBuf) { + return (flags(byteBuf) & FLAGS_M) == FLAGS_M; + } + + /** + * faster version of {@link #frameType(ByteBuf)} which does not replace PAYLOAD with synthetic + * type + */ + public static FrameType nativeFrameType(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + byteBuf.skipBytes(Integer.BYTES); + int typeAndFlags = byteBuf.readShort() & 0xFFFF; + FrameType result = FrameType.fromEncodedType(typeAndFlags >> FRAME_TYPE_SHIFT); + byteBuf.resetReaderIndex(); + return result; + } + + public static FrameType frameType(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + byteBuf.skipBytes(Integer.BYTES); + int typeAndFlags = byteBuf.readShort() & 0xFFFF; + + FrameType result = FrameType.fromEncodedType(typeAndFlags >> FRAME_TYPE_SHIFT); + + if (FrameType.PAYLOAD == result) { + final int flags = typeAndFlags & FRAME_FLAGS_MASK; + + boolean complete = FLAGS_C == (flags & FLAGS_C); + boolean next = FLAGS_N == (flags & FLAGS_N); + if (next && complete) { + result = FrameType.NEXT_COMPLETE; + } else if (complete) { + result = FrameType.COMPLETE; + } else if (next) { + result = FrameType.NEXT; + } else { + throw new IllegalArgumentException("Payload must set either or both of NEXT and COMPLETE."); + } + } + + byteBuf.resetReaderIndex(); + + return result; + } + + public static void ensureFrameType(final FrameType frameType, ByteBuf byteBuf) { + if (!disableFrameTypeCheck) { + final FrameType typeInFrame = frameType(byteBuf); + + if (typeInFrame != frameType) { + throw new AssertionError("expected " + frameType + ", but saw " + typeInFrame); + } + } + } + + public static int size() { + return HEADER_SIZE; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderFlyweight.java deleted file mode 100644 index 94412e8fc..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderFlyweight.java +++ /dev/null @@ -1,359 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.frame; - -import static io.rsocket.frame.FrameHeaderFlyweight.decodeMetadataLength; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.rsocket.Frame; -import io.rsocket.FrameType; -import javax.annotation.Nullable; - -/** - * Per connection frame flyweight. - * - *

Not the latest frame layout, but close. Does not include - fragmentation / reassembly - encode - * should remove Type param and have it as part of method name (1 encode per type?) - * - *

Not thread-safe. Assumed to be used single-threaded - */ -public class FrameHeaderFlyweight { - - private FrameHeaderFlyweight() {} - - public static final int FRAME_HEADER_LENGTH; - - private static final int FRAME_TYPE_BITS = 6; - private static final int FRAME_TYPE_SHIFT = 16 - FRAME_TYPE_BITS; - private static final int FRAME_FLAGS_MASK = 0b0000_0011_1111_1111; - - public static final int FRAME_LENGTH_SIZE = 3; - public static final int FRAME_LENGTH_MASK = 0xFFFFFF; - - private static final int FRAME_LENGTH_FIELD_OFFSET; - private static final int FRAME_TYPE_AND_FLAGS_FIELD_OFFSET; - private static final int STREAM_ID_FIELD_OFFSET; - private static final int PAYLOAD_OFFSET; - - public static final int FLAGS_I = 0b10_0000_0000; - public static final int FLAGS_M = 0b01_0000_0000; - - public static final int FLAGS_F = 0b00_1000_0000; - public static final int FLAGS_C = 0b00_0100_0000; - public static final int FLAGS_N = 0b00_0010_0000; - - static { - FRAME_LENGTH_FIELD_OFFSET = 0; - STREAM_ID_FIELD_OFFSET = FRAME_LENGTH_FIELD_OFFSET + FRAME_LENGTH_SIZE; - FRAME_TYPE_AND_FLAGS_FIELD_OFFSET = STREAM_ID_FIELD_OFFSET + Integer.BYTES; - PAYLOAD_OFFSET = FRAME_TYPE_AND_FLAGS_FIELD_OFFSET + Short.BYTES; - FRAME_HEADER_LENGTH = PAYLOAD_OFFSET; - } - - public static int computeFrameHeaderLength( - final FrameType frameType, @Nullable Integer metadataLength, final int dataLength) { - return PAYLOAD_OFFSET + computeMetadataLength(frameType, metadataLength) + dataLength; - } - - public static int encodeFrameHeader( - final ByteBuf byteBuf, - final int frameLength, - final int flags, - final FrameType frameType, - final int streamId) { - if ((frameLength & ~FRAME_LENGTH_MASK) != 0) { - throw new IllegalArgumentException("Frame length is larger than 24 bits"); - } - - // frame length field needs to be excluded from the length - encodeLength(byteBuf, FRAME_LENGTH_FIELD_OFFSET, frameLength - FRAME_LENGTH_SIZE); - - byteBuf.setInt(STREAM_ID_FIELD_OFFSET, streamId); - short typeAndFlags = (short) (frameType.getEncodedType() << FRAME_TYPE_SHIFT | (short) flags); - byteBuf.setShort(FRAME_TYPE_AND_FLAGS_FIELD_OFFSET, typeAndFlags); - - return FRAME_HEADER_LENGTH; - } - - public static int encodeMetadata( - final ByteBuf byteBuf, - final FrameType frameType, - final int metadataOffset, - final @Nullable ByteBuf metadata) { - int length = 0; - - if (metadata != null) { - final int metadataLength = metadata.readableBytes(); - - int typeAndFlags = byteBuf.getShort(FRAME_TYPE_AND_FLAGS_FIELD_OFFSET); - typeAndFlags |= FLAGS_M; - byteBuf.setShort(FRAME_TYPE_AND_FLAGS_FIELD_OFFSET, (short) typeAndFlags); - - if (hasMetadataLengthField(frameType)) { - encodeLength(byteBuf, metadataOffset, metadataLength); - length += FRAME_LENGTH_SIZE; - } - byteBuf.setBytes(metadataOffset + length, metadata); - length += metadataLength; - } - - return length; - } - - public static int encodeData(final ByteBuf byteBuf, final int dataOffset, final ByteBuf data) { - int length = 0; - final int dataLength = data.readableBytes(); - - if (0 < dataLength) { - byteBuf.setBytes(dataOffset, data); - length += dataLength; - } - - return length; - } - - // only used for types simple enough that they don't have their own FrameFlyweights - public static int encode( - final ByteBuf byteBuf, - final int streamId, - int flags, - final FrameType frameType, - final @Nullable ByteBuf metadata, - final ByteBuf data) { - if (Frame.isFlagSet(flags, FLAGS_M) != (metadata != null)) { - throw new IllegalStateException("bad value for metadata flag"); - } - - final int frameLength = - computeFrameHeaderLength( - frameType, metadata != null ? metadata.readableBytes() : null, data.readableBytes()); - - final FrameType outFrameType; - switch (frameType) { - case PAYLOAD: - throw new IllegalArgumentException( - "Don't encode raw PAYLOAD frames, use NEXT_COMPLETE, COMPLETE or NEXT"); - case NEXT_COMPLETE: - outFrameType = FrameType.PAYLOAD; - flags |= FLAGS_C | FLAGS_N; - break; - case COMPLETE: - outFrameType = FrameType.PAYLOAD; - flags |= FLAGS_C; - break; - case NEXT: - outFrameType = FrameType.PAYLOAD; - flags |= FLAGS_N; - break; - default: - outFrameType = frameType; - break; - } - - int length = encodeFrameHeader(byteBuf, frameLength, flags, outFrameType, streamId); - - length += encodeMetadata(byteBuf, frameType, length, metadata); - length += encodeData(byteBuf, length, data); - - return length; - } - - public static int flags(final ByteBuf byteBuf) { - short typeAndFlags = byteBuf.getShort(FRAME_TYPE_AND_FLAGS_FIELD_OFFSET); - return typeAndFlags & FRAME_FLAGS_MASK; - } - - public static FrameType frameType(final ByteBuf byteBuf) { - int typeAndFlags = byteBuf.getShort(FRAME_TYPE_AND_FLAGS_FIELD_OFFSET); - FrameType result = FrameType.from(typeAndFlags >> FRAME_TYPE_SHIFT); - - if (FrameType.PAYLOAD == result) { - final int flags = typeAndFlags & FRAME_FLAGS_MASK; - - boolean complete = FLAGS_C == (flags & FLAGS_C); - boolean next = FLAGS_N == (flags & FLAGS_N); - if (next && complete) { - result = FrameType.NEXT_COMPLETE; - } else if (complete) { - result = FrameType.COMPLETE; - } else if (next) { - result = FrameType.NEXT; - } else { - throw new IllegalArgumentException("Payload must set either or both of NEXT and COMPLETE."); - } - } - - return result; - } - - public static int streamId(final ByteBuf byteBuf) { - return byteBuf.getInt(STREAM_ID_FIELD_OFFSET); - } - - public static ByteBuf sliceFrameData(final ByteBuf byteBuf) { - final FrameType frameType = frameType(byteBuf); - final int frameLength = frameLength(byteBuf); - final int dataLength = dataLength(byteBuf, frameType); - final int dataOffset = dataOffset(byteBuf, frameType, frameLength); - ByteBuf result = Unpooled.EMPTY_BUFFER; - - if (0 < dataLength) { - result = byteBuf.slice(dataOffset, dataLength); - } - - return result; - } - - public static @Nullable ByteBuf sliceFrameMetadata(final ByteBuf byteBuf) { - final FrameType frameType = frameType(byteBuf); - final int frameLength = frameLength(byteBuf); - final @Nullable Integer metadataLength = metadataLength(byteBuf, frameType, frameLength); - - if (metadataLength == null) { - return null; - } - - int metadataOffset = metadataOffset(byteBuf); - if (hasMetadataLengthField(frameType)) { - metadataOffset += FRAME_LENGTH_SIZE; - } - ByteBuf result = Unpooled.EMPTY_BUFFER; - - if (0 < metadataLength) { - result = byteBuf.slice(metadataOffset, metadataLength); - } - - return result; - } - - public static int frameLength(final ByteBuf byteBuf) { - // frame length field was excluded from the length so we will add it to represent - // the entire block - return decodeLength(byteBuf, FRAME_LENGTH_FIELD_OFFSET) + FRAME_LENGTH_SIZE; - } - - private static int metadataFieldLength(ByteBuf byteBuf, FrameType frameType, int frameLength) { - return computeMetadataLength(frameType, metadataLength(byteBuf, frameType, frameLength)); - } - - public static @Nullable Integer metadataLength( - ByteBuf byteBuf, FrameType frameType, int frameLength) { - if (!hasMetadataLengthField(frameType)) { - return frameLength - metadataOffset(byteBuf); - } else { - return decodeMetadataLength(byteBuf, metadataOffset(byteBuf)); - } - } - - static @Nullable Integer decodeMetadataLength(final ByteBuf byteBuf, final int metadataOffset) { - int flags = flags(byteBuf); - if (FLAGS_M == (FLAGS_M & flags)) { - return decodeLength(byteBuf, metadataOffset); - } else { - return null; - } - } - - private static int computeMetadataLength(FrameType frameType, final @Nullable Integer length) { - if (!hasMetadataLengthField(frameType)) { - // Frames with only metadata does not need metadata length field - return length != null ? length : 0; - } else { - return length == null ? 0 : length + FRAME_LENGTH_SIZE; - } - } - - public static boolean hasMetadataLengthField(FrameType frameType) { - return frameType.canHaveData(); - } - - public static void encodeLength(final ByteBuf byteBuf, final int offset, final int length) { - if ((length & ~FRAME_LENGTH_MASK) != 0) { - throw new IllegalArgumentException("Length is larger than 24 bits"); - } - // Write each byte separately in reverse order, this mean we can write 1 << 23 without - // overflowing. - byteBuf.setByte(offset, length >> 16); - byteBuf.setByte(offset + 1, length >> 8); - byteBuf.setByte(offset + 2, length); - } - - private static int decodeLength(final ByteBuf byteBuf, final int offset) { - int length = (byteBuf.getByte(offset) & 0xFF) << 16; - length |= (byteBuf.getByte(offset + 1) & 0xFF) << 8; - length |= byteBuf.getByte(offset + 2) & 0xFF; - return length; - } - - public static int dataLength(final ByteBuf byteBuf, final FrameType frameType) { - return dataLength(byteBuf, frameType, payloadOffset(byteBuf)); - } - - static int dataLength(final ByteBuf byteBuf, final FrameType frameType, final int payloadOffset) { - final int frameLength = frameLength(byteBuf); - final int metadataLength = metadataFieldLength(byteBuf, frameType, frameLength); - - return frameLength - metadataLength - payloadOffset; - } - - public static int payloadLength(final ByteBuf byteBuf) { - final int frameLength = frameLength(byteBuf); - final int payloadOffset = payloadOffset(byteBuf); - - return frameLength - payloadOffset; - } - - private static int payloadOffset(final ByteBuf byteBuf) { - int typeAndFlags = byteBuf.getShort(FRAME_TYPE_AND_FLAGS_FIELD_OFFSET); - FrameType frameType = FrameType.from(typeAndFlags >> FRAME_TYPE_SHIFT); - int result = PAYLOAD_OFFSET; - - switch (frameType) { - case SETUP: - result = SetupFrameFlyweight.payloadOffset(byteBuf); - break; - case ERROR: - result = ErrorFrameFlyweight.payloadOffset(byteBuf); - break; - case LEASE: - result = LeaseFrameFlyweight.payloadOffset(byteBuf); - break; - case KEEPALIVE: - result = KeepaliveFrameFlyweight.payloadOffset(byteBuf); - break; - case REQUEST_RESPONSE: - case FIRE_AND_FORGET: - case REQUEST_STREAM: - case REQUEST_CHANNEL: - result = RequestFrameFlyweight.payloadOffset(frameType, byteBuf); - break; - case REQUEST_N: - result = RequestNFrameFlyweight.payloadOffset(byteBuf); - break; - } - - return result; - } - - public static int metadataOffset(final ByteBuf byteBuf) { - return payloadOffset(byteBuf); - } - - public static int dataOffset(ByteBuf byteBuf, FrameType frameType, int frameLength) { - return payloadOffset(byteBuf) + metadataFieldLength(byteBuf, frameType, frameLength); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameLengthCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameLengthCodec.java new file mode 100644 index 000000000..f6c19c8ee --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameLengthCodec.java @@ -0,0 +1,54 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; + +/** + * Some transports like TCP aren't framed, and require a length. This is used by DuplexConnections + * for transports that need to send length + */ +public class FrameLengthCodec { + public static final int FRAME_LENGTH_MASK = 0xFFFFFF; + public static final int FRAME_LENGTH_SIZE = 3; + + private FrameLengthCodec() {} + + private static void encodeLength(final ByteBuf byteBuf, final int length) { + if ((length & ~FRAME_LENGTH_MASK) != 0) { + throw new IllegalArgumentException("Length is larger than 24 bits"); + } + // Write each byte separately in reverse order, this mean we can write 1 << 23 without + // overflowing. + byteBuf.writeByte(length >> 16); + byteBuf.writeByte(length >> 8); + byteBuf.writeByte(length); + } + + private static int decodeLength(final ByteBuf byteBuf) { + int length = (byteBuf.readByte() & 0xFF) << 16; + length |= (byteBuf.readByte() & 0xFF) << 8; + length |= byteBuf.readByte() & 0xFF; + return length; + } + + public static ByteBuf encode(ByteBufAllocator allocator, int length, ByteBuf frame) { + ByteBuf buffer = allocator.buffer(); + encodeLength(buffer, length); + return allocator.compositeBuffer(2).addComponents(true, buffer, frame); + } + + public static int length(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + int length = decodeLength(byteBuf); + byteBuf.resetReaderIndex(); + return length; + } + + public static ByteBuf frame(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + byteBuf.skipBytes(3); + ByteBuf slice = byteBuf.slice(); + byteBuf.resetReaderIndex(); + return slice; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameType.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameType.java new file mode 100644 index 000000000..8ac743f87 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameType.java @@ -0,0 +1,315 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.frame; + +import java.util.Arrays; + +/** + * Types of Frame that can be sent. + * + * @see Frame + * Types + */ +public enum FrameType { + + /** Reserved. */ + RESERVED(0x00), + + // CONNECTION + + /** + * Sent by client to initiate protocol processing. + * + * @see Setup + * Frame + */ + SETUP(0x01, Flags.CAN_HAVE_DATA | Flags.CAN_HAVE_METADATA), + + /** + * Sent by Responder to grant the ability to send requests. + * + * @see Lease + * Frame + */ + LEASE(0x02, Flags.CAN_HAVE_METADATA), + + /** + * Connection keepalive. + * + * @see Keepalive + * Frame + */ + KEEPALIVE(0x03, Flags.CAN_HAVE_DATA), + + // START REQUEST + + /** + * Request single response. + * + * @see Request + * Response Frame + */ + REQUEST_RESPONSE( + 0x04, + Flags.CAN_HAVE_DATA + | Flags.CAN_HAVE_METADATA + | Flags.IS_FRAGMENTABLE + | Flags.IS_REQUEST_TYPE), + + /** + * A single one-way message. + * + * @see Request + * Fire-and-Forget Frame + */ + REQUEST_FNF( + 0x05, + Flags.CAN_HAVE_DATA + | Flags.CAN_HAVE_METADATA + | Flags.IS_FRAGMENTABLE + | Flags.IS_REQUEST_TYPE), + + /** + * Request a completable stream. + * + * @see Request + * Stream Frame + */ + REQUEST_STREAM( + 0x06, + Flags.CAN_HAVE_METADATA + | Flags.CAN_HAVE_DATA + | Flags.HAS_INITIAL_REQUEST_N + | Flags.IS_FRAGMENTABLE + | Flags.IS_REQUEST_TYPE), + + /** + * Request a completable stream in both directions. + * + * @see Request + * Channel Frame + */ + REQUEST_CHANNEL( + 0x07, + Flags.CAN_HAVE_METADATA + | Flags.CAN_HAVE_DATA + | Flags.HAS_INITIAL_REQUEST_N + | Flags.IS_FRAGMENTABLE + | Flags.IS_REQUEST_TYPE), + + // DURING REQUEST + + /** + * Request N more items with Reactive Streams semantics. + * + * @see RequestN + * Frame + */ + REQUEST_N(0x08), + + /** + * Cancel outstanding request. + * + * @see Cancel + * Frame + */ + CANCEL(0x09), + + // RESPONSE + + /** + * Payload on a stream. For example, response to a request, or message on a channel. + * + * @see Payload + * Frame + */ + PAYLOAD(0x0A, Flags.CAN_HAVE_DATA | Flags.CAN_HAVE_METADATA | Flags.IS_FRAGMENTABLE), + + /** + * Error at connection or application level. + * + * @see Error + * Frame + */ + ERROR(0x0B, Flags.CAN_HAVE_DATA), + + // METADATA + + /** + * Asynchronous Metadata frame. + * + * @see Metadata + * Push Frame + */ + METADATA_PUSH(0x0C, Flags.CAN_HAVE_METADATA), + + // RESUMPTION + + /** + * Replaces SETUP for Resuming Operation (optional). + * + * @see Resume + * Frame + */ + RESUME(0x0D), + + /** + * Sent in response to a RESUME if resuming operation possible (optional). + * + * @see Resume OK + * Frame + */ + RESUME_OK(0x0E), + + // SYNTHETIC PAYLOAD TYPES + + /** A {@link #PAYLOAD} frame with {@code NEXT} flag set. */ + NEXT(0xA0, Flags.CAN_HAVE_DATA | Flags.CAN_HAVE_METADATA | Flags.IS_FRAGMENTABLE), + + /** A {@link #PAYLOAD} frame with {@code COMPLETE} flag set. */ + COMPLETE(0xB0), + + /** A {@link #PAYLOAD} frame with {@code NEXT} and {@code COMPLETE} flags set. */ + NEXT_COMPLETE(0xC0, Flags.CAN_HAVE_DATA | Flags.CAN_HAVE_METADATA | Flags.IS_FRAGMENTABLE), + + /** + * Used To Extend more frame types as well as extensions. + * + * @see Extension + * Frame + */ + EXT(0x3F, Flags.CAN_HAVE_DATA | Flags.CAN_HAVE_METADATA); + + /** The size of the encoded frame type */ + static final int ENCODED_SIZE = 6; + + private static final FrameType[] FRAME_TYPES_BY_ENCODED_TYPE; + + static { + FRAME_TYPES_BY_ENCODED_TYPE = new FrameType[getMaximumEncodedType() + 1]; + + for (FrameType frameType : values()) { + FRAME_TYPES_BY_ENCODED_TYPE[frameType.encodedType] = frameType; + } + } + + private final int encodedType; + private final int flags; + + FrameType(int encodedType) { + this(encodedType, Flags.EMPTY); + } + + FrameType(int encodedType, int flags) { + this.encodedType = encodedType; + this.flags = flags; + } + + /** + * Returns the {@code FrameType} that matches the specified {@code encodedType}. + * + * @param encodedType the encoded type + * @return the {@code FrameType} that matches the specified {@code encodedType} + */ + public static FrameType fromEncodedType(int encodedType) { + FrameType frameType = FRAME_TYPES_BY_ENCODED_TYPE[encodedType]; + + if (frameType == null) { + throw new IllegalArgumentException(String.format("Frame type %d is unknown", encodedType)); + } + + return frameType; + } + + private static int getMaximumEncodedType() { + return Arrays.stream(values()).mapToInt(frameType -> frameType.encodedType).max().orElse(0); + } + + /** + * Whether the frame type can have data. + * + * @return whether the frame type can have data + */ + public boolean canHaveData() { + return Flags.CAN_HAVE_DATA == (flags & Flags.CAN_HAVE_DATA); + } + + /** + * Whether the frame type can have metadata + * + * @return whether the frame type can have metadata + */ + public boolean canHaveMetadata() { + return Flags.CAN_HAVE_METADATA == (flags & Flags.CAN_HAVE_METADATA); + } + + /** + * Returns the encoded type. + * + * @return the encoded type + */ + public int getEncodedType() { + return encodedType; + } + + /** + * Whether the frame type starts with an initial {@code requestN}. + * + * @return wether the frame type starts with an initial {@code requestN} + */ + public boolean hasInitialRequestN() { + return Flags.HAS_INITIAL_REQUEST_N == (flags & Flags.HAS_INITIAL_REQUEST_N); + } + + /** + * Whether the frame type is fragmentable. + * + * @return whether the frame type is fragmentable + */ + public boolean isFragmentable() { + return Flags.IS_FRAGMENTABLE == (flags & Flags.IS_FRAGMENTABLE); + } + + /** + * Whether the frame type is a request type. + * + * @return whether the frame type is a request type + */ + public boolean isRequestType() { + return Flags.IS_REQUEST_TYPE == (flags & Flags.IS_REQUEST_TYPE); + } + + private static class Flags { + private static final int EMPTY = 0b00000; + private static final int CAN_HAVE_DATA = 0b10000; + private static final int CAN_HAVE_METADATA = 0b01000; + private static final int IS_FRAGMENTABLE = 0b00100; + private static final int IS_REQUEST_TYPE = 0b00010; + private static final int HAS_INITIAL_REQUEST_N = 0b00001; + + private Flags() {} + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java new file mode 100644 index 000000000..d581731a3 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java @@ -0,0 +1,133 @@ +/* + * Copyright 2015-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; + +public class FrameUtil { + + private FrameUtil() {} + + public static String toString(ByteBuf frame) { + FrameType frameType = FrameHeaderCodec.frameType(frame); + int streamId = FrameHeaderCodec.streamId(frame); + StringBuilder payload = new StringBuilder(); + + payload + .append("\nFrame => Stream ID: ") + .append(streamId) + .append(" Type: ") + .append(frameType) + .append(" Flags: 0b") + .append(Integer.toBinaryString(FrameHeaderCodec.flags(frame))) + .append(" Length: " + frame.readableBytes()); + + if (frameType.hasInitialRequestN()) { + payload.append(" InitialRequestN: ").append(RequestStreamFrameCodec.initialRequestN(frame)); + } + + if (frameType == FrameType.REQUEST_N) { + payload.append(" RequestN: ").append(RequestNFrameCodec.requestN(frame)); + } + + if (FrameHeaderCodec.hasMetadata(frame)) { + payload.append("\nMetadata:\n"); + + ByteBufUtil.appendPrettyHexDump(payload, getMetadata(frame, frameType)); + } + + payload.append("\nData:\n"); + ByteBufUtil.appendPrettyHexDump(payload, getData(frame, frameType)); + + return payload.toString(); + } + + private static ByteBuf getMetadata(ByteBuf frame, FrameType frameType) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(frame); + if (hasMetadata) { + ByteBuf metadata; + switch (frameType) { + case REQUEST_FNF: + metadata = RequestFireAndForgetFrameCodec.metadata(frame); + break; + case REQUEST_STREAM: + metadata = RequestStreamFrameCodec.metadata(frame); + break; + case REQUEST_RESPONSE: + metadata = RequestResponseFrameCodec.metadata(frame); + break; + case REQUEST_CHANNEL: + metadata = RequestChannelFrameCodec.metadata(frame); + break; + // Payload and synthetic types + case PAYLOAD: + case NEXT: + case NEXT_COMPLETE: + case COMPLETE: + metadata = PayloadFrameCodec.metadata(frame); + break; + case METADATA_PUSH: + metadata = MetadataPushFrameCodec.metadata(frame); + break; + case SETUP: + metadata = SetupFrameCodec.metadata(frame); + break; + case LEASE: + metadata = LeaseFrameCodec.metadata(frame); + break; + default: + return Unpooled.EMPTY_BUFFER; + } + return metadata; + } else { + return Unpooled.EMPTY_BUFFER; + } + } + + private static ByteBuf getData(ByteBuf frame, FrameType frameType) { + ByteBuf data; + switch (frameType) { + case REQUEST_FNF: + data = RequestFireAndForgetFrameCodec.data(frame); + break; + case REQUEST_STREAM: + data = RequestStreamFrameCodec.data(frame); + break; + case REQUEST_RESPONSE: + data = RequestResponseFrameCodec.data(frame); + break; + case REQUEST_CHANNEL: + data = RequestChannelFrameCodec.data(frame); + break; + // Payload, KeepAlive and synthetic types + case PAYLOAD: + case KEEPALIVE: + case NEXT: + case NEXT_COMPLETE: + case COMPLETE: + data = PayloadFrameCodec.data(frame); + break; + case SETUP: + data = SetupFrameCodec.data(frame); + break; + default: + return Unpooled.EMPTY_BUFFER; + } + return data; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/GenericFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/GenericFrameCodec.java new file mode 100644 index 000000000..56a93d869 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/GenericFrameCodec.java @@ -0,0 +1,159 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +class GenericFrameCodec { + + static ByteBuf encodeReleasingPayload( + final ByteBufAllocator allocator, + final FrameType frameType, + final int streamId, + boolean complete, + boolean next, + final Payload payload) { + return encodeReleasingPayload(allocator, frameType, streamId, complete, next, 0, payload); + } + + static ByteBuf encodeReleasingPayload( + final ByteBufAllocator allocator, + final FrameType frameType, + final int streamId, + boolean complete, + boolean next, + int requestN, + final Payload payload) { + + // if refCnt exceptions throws here it is safe to do no-op + boolean hasMetadata = payload.hasMetadata(); + // if refCnt exceptions throws here it is safe to do no-op still + final ByteBuf metadata = hasMetadata ? payload.metadata().retain() : null; + final ByteBuf data; + // retaining data safely. May throw either NPE or RefCntE + try { + data = payload.data().retain(); + } catch (IllegalReferenceCountException | NullPointerException e) { + if (hasMetadata) { + metadata.release(); + } + throw e; + } + // releasing payload safely since it can be already released wheres we have to release retained + // data and metadata as well + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + data.release(); + if (hasMetadata) { + metadata.release(); + } + throw e; + } + + return encode(allocator, frameType, streamId, false, complete, next, requestN, metadata, data); + } + + static ByteBuf encode( + final ByteBufAllocator allocator, + final FrameType frameType, + final int streamId, + boolean fragmentFollows, + @Nullable ByteBuf metadata, + ByteBuf data) { + return encode(allocator, frameType, streamId, fragmentFollows, false, false, 0, metadata, data); + } + + static ByteBuf encode( + final ByteBufAllocator allocator, + final FrameType frameType, + final int streamId, + boolean fragmentFollows, + boolean complete, + boolean next, + int requestN, + @Nullable ByteBuf metadata, + @Nullable ByteBuf data) { + + final boolean hasMetadata = metadata != null; + + int flags = 0; + + if (hasMetadata) { + flags |= FrameHeaderCodec.FLAGS_M; + } + + if (fragmentFollows) { + flags |= FrameHeaderCodec.FLAGS_F; + } + + if (complete) { + flags |= FrameHeaderCodec.FLAGS_C; + } + + if (next) { + flags |= FrameHeaderCodec.FLAGS_N; + } + + final ByteBuf header = FrameHeaderCodec.encode(allocator, streamId, frameType, flags); + + if (requestN > 0) { + header.writeInt(requestN); + } + + return FrameBodyCodec.encode(allocator, header, metadata, hasMetadata, data); + } + + static ByteBuf data(ByteBuf byteBuf) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + int idx = byteBuf.readerIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + ByteBuf data = FrameBodyCodec.dataWithoutMarking(byteBuf, hasMetadata); + byteBuf.readerIndex(idx); + return data; + } + + @Nullable + static ByteBuf metadata(ByteBuf byteBuf) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + if (!hasMetadata) { + return null; + } + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + ByteBuf metadata = FrameBodyCodec.metadataWithoutMarking(byteBuf); + byteBuf.resetReaderIndex(); + return metadata; + } + + static ByteBuf dataWithRequestN(ByteBuf byteBuf) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); + ByteBuf data = FrameBodyCodec.dataWithoutMarking(byteBuf, hasMetadata); + byteBuf.resetReaderIndex(); + return data; + } + + @Nullable + static ByteBuf metadataWithRequestN(ByteBuf byteBuf) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + if (!hasMetadata) { + return null; + } + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); + ByteBuf metadata = FrameBodyCodec.metadataWithoutMarking(byteBuf); + byteBuf.resetReaderIndex(); + return metadata; + } + + static int initialRequestN(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + int i = byteBuf.skipBytes(FrameHeaderCodec.size()).readInt(); + byteBuf.resetReaderIndex(); + return i; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/KeepAliveFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/KeepAliveFrameCodec.java new file mode 100644 index 000000000..752d5b3eb --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/KeepAliveFrameCodec.java @@ -0,0 +1,56 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; + +public class KeepAliveFrameCodec { + /** + * (R)espond: Set by the sender of the KEEPALIVE, to which the responder MUST reply with a + * KEEPALIVE without the R flag set + */ + public static final int FLAGS_KEEPALIVE_R = 0b00_1000_0000; + + public static final long LAST_POSITION_MASK = 0x8000000000000000L; + + private KeepAliveFrameCodec() {} + + public static ByteBuf encode( + final ByteBufAllocator allocator, + final boolean respond, + final long lastPosition, + final ByteBuf data) { + final int flags = respond ? FLAGS_KEEPALIVE_R : 0; + ByteBuf header = FrameHeaderCodec.encodeStreamZero(allocator, FrameType.KEEPALIVE, flags); + + long lp = 0; + if (lastPosition > 0) { + lp |= lastPosition; + } + + header.writeLong(lp); + + return FrameBodyCodec.encode(allocator, header, null, false, data); + } + + public static boolean respondFlag(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.KEEPALIVE, byteBuf); + int flags = FrameHeaderCodec.flags(byteBuf); + return (flags & FLAGS_KEEPALIVE_R) == FLAGS_KEEPALIVE_R; + } + + public static long lastPosition(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.KEEPALIVE, byteBuf); + byteBuf.markReaderIndex(); + long l = byteBuf.skipBytes(FrameHeaderCodec.size()).readLong(); + byteBuf.resetReaderIndex(); + return l; + } + + public static ByteBuf data(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.KEEPALIVE, byteBuf); + byteBuf.markReaderIndex(); + ByteBuf slice = byteBuf.skipBytes(FrameHeaderCodec.size() + Long.BYTES).slice(); + byteBuf.resetReaderIndex(); + return slice; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/KeepaliveFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/KeepaliveFrameFlyweight.java deleted file mode 100644 index 4da541b2e..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/KeepaliveFrameFlyweight.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.rsocket.FrameType; - -public class KeepaliveFrameFlyweight { - public static final int FLAGS_KEEPALIVE_R = 0b00_1000_0000; - - private KeepaliveFrameFlyweight() {} - - private static final int LAST_POSITION_OFFSET = FrameHeaderFlyweight.FRAME_HEADER_LENGTH; - private static final int PAYLOAD_OFFSET = LAST_POSITION_OFFSET + Long.BYTES; - - public static int computeFrameLength(final int dataLength) { - return FrameHeaderFlyweight.computeFrameHeaderLength(FrameType.SETUP, null, dataLength) - + Long.BYTES; - } - - public static int encode(final ByteBuf byteBuf, int flags, final ByteBuf data) { - final int frameLength = computeFrameLength(data.readableBytes()); - - int length = - FrameHeaderFlyweight.encodeFrameHeader(byteBuf, frameLength, flags, FrameType.KEEPALIVE, 0); - - // We don't support resumability, last position is always zero - byteBuf.setLong(length, 0); - length += Long.BYTES; - - length += FrameHeaderFlyweight.encodeData(byteBuf, length, data); - - return length; - } - - public static int payloadOffset(final ByteBuf byteBuf) { - return PAYLOAD_OFFSET; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameCodec.java new file mode 100644 index 000000000..f20c25d3b --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameCodec.java @@ -0,0 +1,83 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import reactor.util.annotation.Nullable; + +public class LeaseFrameCodec { + + public static ByteBuf encode( + final ByteBufAllocator allocator, + final int ttl, + final int numRequests, + @Nullable final ByteBuf metadata) { + + final boolean hasMetadata = metadata != null; + + int flags = 0; + + if (hasMetadata) { + flags |= FrameHeaderCodec.FLAGS_M; + } + + final ByteBuf header = + FrameHeaderCodec.encodeStreamZero(allocator, FrameType.LEASE, flags) + .writeInt(ttl) + .writeInt(numRequests); + + final boolean addMetadata; + if (hasMetadata) { + if (metadata.isReadable()) { + addMetadata = true; + } else { + // even though there is nothing to read, we still have to release here since nobody else + // going to do soo + metadata.release(); + addMetadata = false; + } + } else { + // has no metadata means it is null, thus no need to release anything + addMetadata = false; + } + + if (addMetadata) { + return allocator.compositeBuffer(2).addComponents(true, header, metadata); + } else { + return header; + } + } + + public static int ttl(final ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.LEASE, byteBuf); + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + int ttl = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + return ttl; + } + + public static int numRequests(final ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.LEASE, byteBuf); + byteBuf.markReaderIndex(); + // Ttl + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); + int numRequests = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + return numRequests; + } + + @Nullable + public static ByteBuf metadata(final ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.LEASE, byteBuf); + if (FrameHeaderCodec.hasMetadata(byteBuf)) { + byteBuf.markReaderIndex(); + // Ttl + Num of requests + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES * 2); + ByteBuf metadata = byteBuf.slice(); + byteBuf.resetReaderIndex(); + return metadata; + } else { + return null; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameFlyweight.java deleted file mode 100644 index 5312ef821..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameFlyweight.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.rsocket.FrameType; - -public class LeaseFrameFlyweight { - private LeaseFrameFlyweight() {} - - // relative to start of passed offset - private static final int TTL_FIELD_OFFSET = FrameHeaderFlyweight.FRAME_HEADER_LENGTH; - private static final int NUM_REQUESTS_FIELD_OFFSET = TTL_FIELD_OFFSET + Integer.BYTES; - private static final int PAYLOAD_OFFSET = NUM_REQUESTS_FIELD_OFFSET + Integer.BYTES; - - public static int computeFrameLength(final int metadataLength) { - int length = FrameHeaderFlyweight.computeFrameHeaderLength(FrameType.LEASE, metadataLength, 0); - return length + Integer.BYTES * 2; - } - - public static int encode( - final ByteBuf byteBuf, final int ttl, final int numRequests, final ByteBuf metadata) { - final int frameLength = computeFrameLength(metadata.readableBytes()); - - int length = - FrameHeaderFlyweight.encodeFrameHeader(byteBuf, frameLength, 0, FrameType.LEASE, 0); - - byteBuf.setInt(TTL_FIELD_OFFSET, ttl); - byteBuf.setInt(NUM_REQUESTS_FIELD_OFFSET, numRequests); - - length += Integer.BYTES * 2; - length += FrameHeaderFlyweight.encodeMetadata(byteBuf, FrameType.LEASE, length, metadata); - - return length; - } - - public static int ttl(final ByteBuf byteBuf) { - return byteBuf.getInt(TTL_FIELD_OFFSET); - } - - public static int numRequests(final ByteBuf byteBuf) { - return byteBuf.getInt(NUM_REQUESTS_FIELD_OFFSET); - } - - public static int payloadOffset(final ByteBuf byteBuf) { - return PAYLOAD_OFFSET; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameCodec.java new file mode 100644 index 000000000..d8ffe3eef --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameCodec.java @@ -0,0 +1,43 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; + +public class MetadataPushFrameCodec { + + public static ByteBuf encodeReleasingPayload(ByteBufAllocator allocator, Payload payload) { + if (!payload.hasMetadata()) { + throw new IllegalStateException( + "Metadata push requires to have metadata present" + " in the given Payload"); + } + final ByteBuf metadata = payload.metadata().retain(); + // releasing payload safely since it can be already released wheres we have to release retained + // data and metadata as well + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + metadata.release(); + throw e; + } + return encode(allocator, metadata); + } + + public static ByteBuf encode(ByteBufAllocator allocator, ByteBuf metadata) { + ByteBuf header = + FrameHeaderCodec.encodeStreamZero( + allocator, FrameType.METADATA_PUSH, FrameHeaderCodec.FLAGS_M); + return allocator.compositeBuffer(2).addComponents(true, header, metadata); + } + + public static ByteBuf metadata(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + int headerSize = FrameHeaderCodec.size(); + int metadataLength = byteBuf.readableBytes() - headerSize; + byteBuf.skipBytes(headerSize); + ByteBuf metadata = byteBuf.readSlice(metadataLength); + byteBuf.resetReaderIndex(); + return metadata; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameCodec.java new file mode 100644 index 000000000..1ae9c6750 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameCodec.java @@ -0,0 +1,56 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class PayloadFrameCodec { + + private PayloadFrameCodec() {} + + public static ByteBuf encodeNextReleasingPayload( + ByteBufAllocator allocator, int streamId, Payload payload) { + + return encodeReleasingPayload(allocator, streamId, false, payload); + } + + public static ByteBuf encodeNextCompleteReleasingPayload( + ByteBufAllocator allocator, int streamId, Payload payload) { + + return encodeReleasingPayload(allocator, streamId, true, payload); + } + + static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, int streamId, boolean complete, Payload payload) { + + return GenericFrameCodec.encodeReleasingPayload( + allocator, FrameType.PAYLOAD, streamId, complete, true, payload); + } + + public static ByteBuf encodeComplete(ByteBufAllocator allocator, int streamId) { + return encode(allocator, streamId, false, true, false, null, null); + } + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + boolean fragmentFollows, + boolean complete, + boolean next, + @Nullable ByteBuf metadata, + @Nullable ByteBuf data) { + + return GenericFrameCodec.encode( + allocator, FrameType.PAYLOAD, streamId, fragmentFollows, complete, next, 0, metadata, data); + } + + public static ByteBuf data(ByteBuf byteBuf) { + return GenericFrameCodec.data(byteBuf); + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + return GenericFrameCodec.metadata(byteBuf); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameCodec.java new file mode 100644 index 000000000..60906083d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameCodec.java @@ -0,0 +1,69 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class RequestChannelFrameCodec { + + private RequestChannelFrameCodec() {} + + public static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, + int streamId, + boolean complete, + long initialRequestN, + Payload payload) { + + if (initialRequestN < 1) { + throw new IllegalArgumentException("request n is less than 1"); + } + + int reqN = initialRequestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) initialRequestN; + + return GenericFrameCodec.encodeReleasingPayload( + allocator, FrameType.REQUEST_CHANNEL, streamId, complete, false, reqN, payload); + } + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + boolean fragmentFollows, + boolean complete, + long initialRequestN, + @Nullable ByteBuf metadata, + ByteBuf data) { + + if (initialRequestN < 1) { + throw new IllegalArgumentException("request n is less than 1"); + } + + int reqN = initialRequestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) initialRequestN; + + return GenericFrameCodec.encode( + allocator, + FrameType.REQUEST_CHANNEL, + streamId, + fragmentFollows, + complete, + false, + reqN, + metadata, + data); + } + + public static ByteBuf data(ByteBuf byteBuf) { + return GenericFrameCodec.dataWithRequestN(byteBuf); + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + return GenericFrameCodec.metadataWithRequestN(byteBuf); + } + + public static long initialRequestN(ByteBuf byteBuf) { + int requestN = GenericFrameCodec.initialRequestN(byteBuf); + return requestN == Integer.MAX_VALUE ? Long.MAX_VALUE : requestN; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameCodec.java new file mode 100644 index 000000000..b91199179 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameCodec.java @@ -0,0 +1,38 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class RequestFireAndForgetFrameCodec { + + private RequestFireAndForgetFrameCodec() {} + + public static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, int streamId, Payload payload) { + + return GenericFrameCodec.encodeReleasingPayload( + allocator, FrameType.REQUEST_FNF, streamId, false, false, payload); + } + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + boolean fragmentFollows, + @Nullable ByteBuf metadata, + ByteBuf data) { + + return GenericFrameCodec.encode( + allocator, FrameType.REQUEST_FNF, streamId, fragmentFollows, metadata, data); + } + + public static ByteBuf data(ByteBuf byteBuf) { + return GenericFrameCodec.data(byteBuf); + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + return GenericFrameCodec.metadata(byteBuf); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestFrameFlyweight.java deleted file mode 100644 index ac474598b..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestFrameFlyweight.java +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.rsocket.Frame; -import io.rsocket.FrameType; -import javax.annotation.Nullable; - -public class RequestFrameFlyweight { - - private RequestFrameFlyweight() {} - - // relative to start of passed offset - private static final int INITIAL_REQUEST_N_FIELD_OFFSET = - FrameHeaderFlyweight.FRAME_HEADER_LENGTH; - - public static int computeFrameLength( - final FrameType type, final @Nullable Integer metadataLength, final int dataLength) { - int length = FrameHeaderFlyweight.computeFrameHeaderLength(type, metadataLength, dataLength); - - if (type.hasInitialRequestN()) { - length += Integer.BYTES; - } - - return length; - } - - public static int encode( - final ByteBuf byteBuf, - final int streamId, - int flags, - final FrameType type, - final int initialRequestN, - final @Nullable ByteBuf metadata, - final ByteBuf data) { - if (Frame.isFlagSet(flags, FrameHeaderFlyweight.FLAGS_M) != (metadata != null)) { - throw new IllegalArgumentException("metadata flag set incorrectly"); - } - - final int frameLength = - computeFrameLength( - type, metadata != null ? metadata.readableBytes() : null, data.readableBytes()); - - int length = - FrameHeaderFlyweight.encodeFrameHeader(byteBuf, frameLength, flags, type, streamId); - - byteBuf.setInt(INITIAL_REQUEST_N_FIELD_OFFSET, initialRequestN); - length += Integer.BYTES; - - length += FrameHeaderFlyweight.encodeMetadata(byteBuf, type, length, metadata); - length += FrameHeaderFlyweight.encodeData(byteBuf, length, data); - - return length; - } - - public static int encode( - final ByteBuf byteBuf, - final int streamId, - final int flags, - final FrameType type, - final @Nullable ByteBuf metadata, - final ByteBuf data) { - if (Frame.isFlagSet(flags, FrameHeaderFlyweight.FLAGS_M) != (metadata != null)) { - throw new IllegalArgumentException("metadata flag set incorrectly"); - } - if (type.hasInitialRequestN()) { - throw new AssertionError(type + " must not be encoded without initial request N"); - } - final int frameLength = - computeFrameLength( - type, metadata != null ? metadata.readableBytes() : null, data.readableBytes()); - - int length = - FrameHeaderFlyweight.encodeFrameHeader(byteBuf, frameLength, flags, type, streamId); - - length += FrameHeaderFlyweight.encodeMetadata(byteBuf, type, length, metadata); - length += FrameHeaderFlyweight.encodeData(byteBuf, length, data); - - return length; - } - - public static int initialRequestN(final ByteBuf byteBuf) { - return byteBuf.getInt(INITIAL_REQUEST_N_FIELD_OFFSET); - } - - public static int payloadOffset(final FrameType type, final ByteBuf byteBuf) { - int result = FrameHeaderFlyweight.FRAME_HEADER_LENGTH; - - if (type.hasInitialRequestN()) { - result += Integer.BYTES; - } - - return result; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameCodec.java new file mode 100644 index 000000000..66bdd46f4 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameCodec.java @@ -0,0 +1,30 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; + +public class RequestNFrameCodec { + private RequestNFrameCodec() {} + + public static ByteBuf encode( + final ByteBufAllocator allocator, final int streamId, long requestN) { + + if (requestN < 1) { + throw new IllegalArgumentException("request n is less than 1"); + } + + int reqN = requestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) requestN; + + ByteBuf header = FrameHeaderCodec.encode(allocator, streamId, FrameType.REQUEST_N, 0); + return header.writeInt(reqN); + } + + public static long requestN(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.REQUEST_N, byteBuf); + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + int i = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + return i == Integer.MAX_VALUE ? Long.MAX_VALUE : i; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameFlyweight.java deleted file mode 100644 index 2e2daef4a..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameFlyweight.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.frame; - -import io.netty.buffer.ByteBuf; -import io.rsocket.FrameType; - -public class RequestNFrameFlyweight { - private RequestNFrameFlyweight() {} - - // relative to start of passed offset - private static final int REQUEST_N_FIELD_OFFSET = FrameHeaderFlyweight.FRAME_HEADER_LENGTH; - - public static int computeFrameLength() { - int length = FrameHeaderFlyweight.computeFrameHeaderLength(FrameType.REQUEST_N, 0, 0); - - return length + Integer.BYTES; - } - - public static int encode(final ByteBuf byteBuf, final int streamId, final int requestN) { - final int frameLength = computeFrameLength(); - - int length = - FrameHeaderFlyweight.encodeFrameHeader( - byteBuf, frameLength, 0, FrameType.REQUEST_N, streamId); - - byteBuf.setInt(REQUEST_N_FIELD_OFFSET, requestN); - - return length + Integer.BYTES; - } - - public static int requestN(final ByteBuf byteBuf) { - return byteBuf.getInt(REQUEST_N_FIELD_OFFSET); - } - - public static int payloadOffset(final ByteBuf byteBuf) { - return FrameHeaderFlyweight.FRAME_HEADER_LENGTH + Integer.BYTES; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameCodec.java new file mode 100644 index 000000000..4a37acfd5 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameCodec.java @@ -0,0 +1,37 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class RequestResponseFrameCodec { + + private RequestResponseFrameCodec() {} + + public static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, int streamId, Payload payload) { + + return GenericFrameCodec.encodeReleasingPayload( + allocator, FrameType.REQUEST_RESPONSE, streamId, false, false, payload); + } + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + boolean fragmentFollows, + @Nullable ByteBuf metadata, + ByteBuf data) { + return GenericFrameCodec.encode( + allocator, FrameType.REQUEST_RESPONSE, streamId, fragmentFollows, metadata, data); + } + + public static ByteBuf data(ByteBuf byteBuf) { + return GenericFrameCodec.data(byteBuf); + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + return GenericFrameCodec.metadata(byteBuf); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameCodec.java new file mode 100644 index 000000000..2f5dbf0d8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameCodec.java @@ -0,0 +1,64 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class RequestStreamFrameCodec { + + private RequestStreamFrameCodec() {} + + public static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, int streamId, long initialRequestN, Payload payload) { + + if (initialRequestN < 1) { + throw new IllegalArgumentException("request n is less than 1"); + } + + int reqN = initialRequestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) initialRequestN; + + return GenericFrameCodec.encodeReleasingPayload( + allocator, FrameType.REQUEST_STREAM, streamId, false, false, reqN, payload); + } + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + boolean fragmentFollows, + long initialRequestN, + @Nullable ByteBuf metadata, + ByteBuf data) { + + if (initialRequestN < 1) { + throw new IllegalArgumentException("request n is less than 1"); + } + + int reqN = initialRequestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) initialRequestN; + + return GenericFrameCodec.encode( + allocator, + FrameType.REQUEST_STREAM, + streamId, + fragmentFollows, + false, + false, + reqN, + metadata, + data); + } + + public static ByteBuf data(ByteBuf byteBuf) { + return GenericFrameCodec.dataWithRequestN(byteBuf); + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + return GenericFrameCodec.metadataWithRequestN(byteBuf); + } + + public static long initialRequestN(ByteBuf byteBuf) { + int requestN = GenericFrameCodec.initialRequestN(byteBuf); + return requestN == Integer.MAX_VALUE ? Long.MAX_VALUE : requestN; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ResumeFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/ResumeFrameCodec.java new file mode 100644 index 000000000..aae89f7ab --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/ResumeFrameCodec.java @@ -0,0 +1,112 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import java.util.UUID; + +public class ResumeFrameCodec { + static final int CURRENT_VERSION = SetupFrameCodec.CURRENT_VERSION; + + public static ByteBuf encode( + ByteBufAllocator allocator, + ByteBuf token, + long lastReceivedServerPos, + long firstAvailableClientPos) { + + ByteBuf byteBuf = FrameHeaderCodec.encodeStreamZero(allocator, FrameType.RESUME, 0); + byteBuf.writeInt(CURRENT_VERSION); + token.markReaderIndex(); + byteBuf.writeShort(token.readableBytes()); + byteBuf.writeBytes(token); + token.resetReaderIndex(); + byteBuf.writeLong(lastReceivedServerPos); + byteBuf.writeLong(firstAvailableClientPos); + + return byteBuf; + } + + public static int version(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.RESUME, byteBuf); + + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + int version = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + + return version; + } + + public static ByteBuf token(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.RESUME, byteBuf); + + byteBuf.markReaderIndex(); + // header + version + int tokenPos = FrameHeaderCodec.size() + Integer.BYTES; + byteBuf.skipBytes(tokenPos); + // token + int tokenLength = byteBuf.readShort() & 0xFFFF; + ByteBuf token = byteBuf.readSlice(tokenLength); + byteBuf.resetReaderIndex(); + + return token; + } + + public static long lastReceivedServerPos(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.RESUME, byteBuf); + + byteBuf.markReaderIndex(); + // header + version + int tokenPos = FrameHeaderCodec.size() + Integer.BYTES; + byteBuf.skipBytes(tokenPos); + // token + int tokenLength = byteBuf.readShort() & 0xFFFF; + byteBuf.skipBytes(tokenLength); + long lastReceivedServerPos = byteBuf.readLong(); + byteBuf.resetReaderIndex(); + + return lastReceivedServerPos; + } + + public static long firstAvailableClientPos(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.RESUME, byteBuf); + + byteBuf.markReaderIndex(); + // header + version + int tokenPos = FrameHeaderCodec.size() + Integer.BYTES; + byteBuf.skipBytes(tokenPos); + // token + int tokenLength = byteBuf.readShort() & 0xFFFF; + byteBuf.skipBytes(tokenLength); + // last received server position + byteBuf.skipBytes(Long.BYTES); + long firstAvailableClientPos = byteBuf.readLong(); + byteBuf.resetReaderIndex(); + + return firstAvailableClientPos; + } + + public static ByteBuf generateResumeToken() { + UUID uuid = UUID.randomUUID(); + ByteBuf bb = Unpooled.buffer(16); + bb.writeLong(uuid.getMostSignificantBits()); + bb.writeLong(uuid.getLeastSignificantBits()); + return bb; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ResumeOkFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/ResumeOkFrameCodec.java new file mode 100644 index 000000000..2b6951e49 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/ResumeOkFrameCodec.java @@ -0,0 +1,22 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; + +public class ResumeOkFrameCodec { + + public static ByteBuf encode(final ByteBufAllocator allocator, long lastReceivedClientPos) { + ByteBuf byteBuf = FrameHeaderCodec.encodeStreamZero(allocator, FrameType.RESUME_OK, 0); + byteBuf.writeLong(lastReceivedClientPos); + return byteBuf; + } + + public static long lastReceivedClientPos(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.RESUME_OK, byteBuf); + byteBuf.markReaderIndex(); + long lastReceivedClientPosition = byteBuf.skipBytes(FrameHeaderCodec.size()).readLong(); + byteBuf.resetReaderIndex(); + + return lastReceivedClientPosition; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/SetupFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/SetupFrameCodec.java new file mode 100644 index 000000000..547e2436e --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/SetupFrameCodec.java @@ -0,0 +1,226 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; +import java.nio.charset.StandardCharsets; +import reactor.util.annotation.Nullable; + +public class SetupFrameCodec { + /** + * A flag used to indicate that the client requires connection resumption, if possible (the frame + * contains a Resume Identification Token) + */ + public static final int FLAGS_RESUME_ENABLE = 0b00_1000_0000; + + /** A flag used to indicate that the client will honor LEASE sent by the server */ + public static final int FLAGS_WILL_HONOR_LEASE = 0b00_0100_0000; + + public static final int CURRENT_VERSION = VersionCodec.encode(1, 0); + + private static final int VERSION_FIELD_OFFSET = FrameHeaderCodec.size(); + private static final int KEEPALIVE_INTERVAL_FIELD_OFFSET = VERSION_FIELD_OFFSET + Integer.BYTES; + private static final int KEEPALIVE_MAX_LIFETIME_FIELD_OFFSET = + KEEPALIVE_INTERVAL_FIELD_OFFSET + Integer.BYTES; + private static final int VARIABLE_DATA_OFFSET = + KEEPALIVE_MAX_LIFETIME_FIELD_OFFSET + Integer.BYTES; + + public static ByteBuf encode( + final ByteBufAllocator allocator, + final boolean lease, + final int keepaliveInterval, + final int maxLifetime, + final String metadataMimeType, + final String dataMimeType, + final Payload setupPayload) { + return encode( + allocator, + lease, + keepaliveInterval, + maxLifetime, + Unpooled.EMPTY_BUFFER, + metadataMimeType, + dataMimeType, + setupPayload); + } + + public static ByteBuf encode( + final ByteBufAllocator allocator, + final boolean lease, + final int keepaliveInterval, + final int maxLifetime, + final ByteBuf resumeToken, + final String metadataMimeType, + final String dataMimeType, + final Payload setupPayload) { + + final ByteBuf data = setupPayload.sliceData(); + final boolean hasMetadata = setupPayload.hasMetadata(); + final ByteBuf metadata = hasMetadata ? setupPayload.sliceMetadata() : null; + + int flags = 0; + + if (resumeToken.readableBytes() > 0) { + flags |= FLAGS_RESUME_ENABLE; + } + + if (lease) { + flags |= FLAGS_WILL_HONOR_LEASE; + } + + if (hasMetadata) { + flags |= FrameHeaderCodec.FLAGS_M; + } + + final ByteBuf header = FrameHeaderCodec.encodeStreamZero(allocator, FrameType.SETUP, flags); + + header.writeInt(CURRENT_VERSION).writeInt(keepaliveInterval).writeInt(maxLifetime); + + if ((flags & FLAGS_RESUME_ENABLE) != 0) { + resumeToken.markReaderIndex(); + header.writeShort(resumeToken.readableBytes()).writeBytes(resumeToken); + resumeToken.resetReaderIndex(); + } + + // Write metadata mime-type + int length = ByteBufUtil.utf8Bytes(metadataMimeType); + header.writeByte(length); + ByteBufUtil.writeUtf8(header, metadataMimeType); + + // Write data mime-type + length = ByteBufUtil.utf8Bytes(dataMimeType); + header.writeByte(length); + ByteBufUtil.writeUtf8(header, dataMimeType); + + return FrameBodyCodec.encode(allocator, header, metadata, hasMetadata, data); + } + + public static int version(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.SETUP, byteBuf); + byteBuf.markReaderIndex(); + int version = byteBuf.skipBytes(VERSION_FIELD_OFFSET).readInt(); + byteBuf.resetReaderIndex(); + return version; + } + + public static String humanReadableVersion(ByteBuf byteBuf) { + int encodedVersion = version(byteBuf); + return VersionCodec.major(encodedVersion) + "." + VersionCodec.minor(encodedVersion); + } + + public static boolean isSupportedVersion(ByteBuf byteBuf) { + return CURRENT_VERSION == version(byteBuf); + } + + public static int resumeTokenLength(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + int tokenLength = byteBuf.skipBytes(VARIABLE_DATA_OFFSET).readShort() & 0xFFFF; + byteBuf.resetReaderIndex(); + return tokenLength; + } + + public static int keepAliveInterval(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + int keepAliveInterval = byteBuf.skipBytes(KEEPALIVE_INTERVAL_FIELD_OFFSET).readInt(); + byteBuf.resetReaderIndex(); + return keepAliveInterval; + } + + public static int keepAliveMaxLifetime(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + int keepAliveMaxLifetime = byteBuf.skipBytes(KEEPALIVE_MAX_LIFETIME_FIELD_OFFSET).readInt(); + byteBuf.resetReaderIndex(); + return keepAliveMaxLifetime; + } + + public static boolean honorLease(ByteBuf byteBuf) { + return (FLAGS_WILL_HONOR_LEASE & FrameHeaderCodec.flags(byteBuf)) == FLAGS_WILL_HONOR_LEASE; + } + + public static boolean resumeEnabled(ByteBuf byteBuf) { + return (FLAGS_RESUME_ENABLE & FrameHeaderCodec.flags(byteBuf)) == FLAGS_RESUME_ENABLE; + } + + public static ByteBuf resumeToken(ByteBuf byteBuf) { + if (resumeEnabled(byteBuf)) { + byteBuf.markReaderIndex(); + // header + int resumePos = + FrameHeaderCodec.size() + + + // version + Integer.BYTES + + + // keep-alive interval + Integer.BYTES + + + // keep-alive maxLifeTime + Integer.BYTES; + + int tokenLength = byteBuf.skipBytes(resumePos).readShort() & 0xFFFF; + ByteBuf resumeToken = byteBuf.readSlice(tokenLength); + byteBuf.resetReaderIndex(); + return resumeToken; + } else { + return Unpooled.EMPTY_BUFFER; + } + } + + public static String metadataMimeType(ByteBuf byteBuf) { + int skip = bytesToSkipToMimeType(byteBuf); + byteBuf.markReaderIndex(); + int length = byteBuf.skipBytes(skip).readUnsignedByte(); + String mimeType = byteBuf.slice(byteBuf.readerIndex(), length).toString(StandardCharsets.UTF_8); + byteBuf.resetReaderIndex(); + return mimeType; + } + + public static String dataMimeType(ByteBuf byteBuf) { + int skip = bytesToSkipToMimeType(byteBuf); + byteBuf.markReaderIndex(); + int metadataLength = byteBuf.skipBytes(skip).readByte(); + int dataLength = byteBuf.skipBytes(metadataLength).readByte(); + String mimeType = byteBuf.readSlice(dataLength).toString(StandardCharsets.UTF_8); + byteBuf.resetReaderIndex(); + return mimeType; + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + if (!hasMetadata) { + return null; + } + byteBuf.markReaderIndex(); + skipToPayload(byteBuf); + ByteBuf metadata = FrameBodyCodec.metadataWithoutMarking(byteBuf); + byteBuf.resetReaderIndex(); + return metadata; + } + + public static ByteBuf data(ByteBuf byteBuf) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + byteBuf.markReaderIndex(); + skipToPayload(byteBuf); + ByteBuf data = FrameBodyCodec.dataWithoutMarking(byteBuf, hasMetadata); + byteBuf.resetReaderIndex(); + return data; + } + + private static int bytesToSkipToMimeType(ByteBuf byteBuf) { + int bytesToSkip = VARIABLE_DATA_OFFSET; + if ((FLAGS_RESUME_ENABLE & FrameHeaderCodec.flags(byteBuf)) == FLAGS_RESUME_ENABLE) { + bytesToSkip += resumeTokenLength(byteBuf) + Short.BYTES; + } + return bytesToSkip; + } + + private static void skipToPayload(ByteBuf byteBuf) { + int skip = bytesToSkipToMimeType(byteBuf); + byte length = byteBuf.skipBytes(skip).readByte(); + length = byteBuf.skipBytes(length).readByte(); + byteBuf.skipBytes(length); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/SetupFrameFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/SetupFrameFlyweight.java deleted file mode 100644 index cc1edc30c..000000000 --- a/rsocket-core/src/main/java/io/rsocket/frame/SetupFrameFlyweight.java +++ /dev/null @@ -1,213 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.frame; - -import static io.rsocket.frame.FrameHeaderFlyweight.FLAGS_M; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.rsocket.FrameType; -import java.nio.charset.StandardCharsets; - -public class SetupFrameFlyweight { - private SetupFrameFlyweight() {} - - public static final int FLAGS_RESUME_ENABLE = 0b00_1000_0000; - public static final int FLAGS_WILL_HONOR_LEASE = 0b00_0100_0000; - public static final int FLAGS_STRICT_INTERPRETATION = 0b00_0010_0000; - - public static final int VALID_FLAGS = - FLAGS_RESUME_ENABLE | FLAGS_WILL_HONOR_LEASE | FLAGS_STRICT_INTERPRETATION | FLAGS_M; - - public static final int CURRENT_VERSION = VersionFlyweight.encode(1, 0); - - // relative to start of passed offset - private static final int VERSION_FIELD_OFFSET = FrameHeaderFlyweight.FRAME_HEADER_LENGTH; - private static final int KEEPALIVE_INTERVAL_FIELD_OFFSET = VERSION_FIELD_OFFSET + Integer.BYTES; - private static final int MAX_LIFETIME_FIELD_OFFSET = - KEEPALIVE_INTERVAL_FIELD_OFFSET + Integer.BYTES; - private static final int VARIABLE_DATA_OFFSET = MAX_LIFETIME_FIELD_OFFSET + Integer.BYTES; - - public static int computeFrameLength( - final int flags, - final String metadataMimeType, - final String dataMimeType, - final int metadataLength, - final int dataLength) { - return computeFrameLength(flags, 0, metadataMimeType, dataMimeType, metadataLength, dataLength); - } - - private static int computeFrameLength( - final int flags, - final int resumeTokenLength, - final String metadataMimeType, - final String dataMimeType, - final int metadataLength, - final int dataLength) { - int length = - FrameHeaderFlyweight.computeFrameHeaderLength(FrameType.SETUP, metadataLength, dataLength); - - length += Integer.BYTES * 3; - - if ((flags & FLAGS_RESUME_ENABLE) != 0) { - length += Short.BYTES + resumeTokenLength; - } - - length += 1 + metadataMimeType.getBytes(StandardCharsets.UTF_8).length; - length += 1 + dataMimeType.getBytes(StandardCharsets.UTF_8).length; - - return length; - } - - public static int encode( - final ByteBuf byteBuf, - int flags, - final int keepaliveInterval, - final int maxLifetime, - final String metadataMimeType, - final String dataMimeType, - final ByteBuf metadata, - final ByteBuf data) { - if ((flags & FLAGS_RESUME_ENABLE) != 0) { - throw new IllegalArgumentException("RESUME_ENABLE not supported"); - } - - return encode( - byteBuf, - flags, - keepaliveInterval, - maxLifetime, - Unpooled.EMPTY_BUFFER, - metadataMimeType, - dataMimeType, - metadata, - data); - } - - // Only exposed for testing, other code shouldn't create frames with resumption tokens for now - static int encode( - final ByteBuf byteBuf, - int flags, - final int keepaliveInterval, - final int maxLifetime, - final ByteBuf resumeToken, - final String metadataMimeType, - final String dataMimeType, - final ByteBuf metadata, - final ByteBuf data) { - final int frameLength = - computeFrameLength( - flags, - resumeToken.readableBytes(), - metadataMimeType, - dataMimeType, - metadata.readableBytes(), - data.readableBytes()); - - int length = - FrameHeaderFlyweight.encodeFrameHeader(byteBuf, frameLength, flags, FrameType.SETUP, 0); - - byteBuf.setInt(VERSION_FIELD_OFFSET, CURRENT_VERSION); - byteBuf.setInt(KEEPALIVE_INTERVAL_FIELD_OFFSET, keepaliveInterval); - byteBuf.setInt(MAX_LIFETIME_FIELD_OFFSET, maxLifetime); - - length += Integer.BYTES * 3; - - if ((flags & FLAGS_RESUME_ENABLE) != 0) { - byteBuf.setShort(length, resumeToken.readableBytes()); - length += Short.BYTES; - int resumeTokenLength = resumeToken.readableBytes(); - byteBuf.setBytes(length, resumeToken, resumeTokenLength); - length += resumeTokenLength; - } - - length += putMimeType(byteBuf, length, metadataMimeType); - length += putMimeType(byteBuf, length, dataMimeType); - - length += FrameHeaderFlyweight.encodeMetadata(byteBuf, FrameType.SETUP, length, metadata); - length += FrameHeaderFlyweight.encodeData(byteBuf, length, data); - - return length; - } - - public static int version(final ByteBuf byteBuf) { - return byteBuf.getInt(VERSION_FIELD_OFFSET); - } - - public static int keepaliveInterval(final ByteBuf byteBuf) { - return byteBuf.getInt(KEEPALIVE_INTERVAL_FIELD_OFFSET); - } - - public static int maxLifetime(final ByteBuf byteBuf) { - return byteBuf.getInt(MAX_LIFETIME_FIELD_OFFSET); - } - - public static String metadataMimeType(final ByteBuf byteBuf) { - final byte[] bytes = getMimeType(byteBuf, metadataMimetypeOffset(byteBuf)); - return new String(bytes, StandardCharsets.UTF_8); - } - - public static String dataMimeType(final ByteBuf byteBuf) { - int fieldOffset = metadataMimetypeOffset(byteBuf); - - fieldOffset += 1 + byteBuf.getByte(fieldOffset); - - final byte[] bytes = getMimeType(byteBuf, fieldOffset); - return new String(bytes, StandardCharsets.UTF_8); - } - - public static int payloadOffset(final ByteBuf byteBuf) { - int fieldOffset = metadataMimetypeOffset(byteBuf); - - final int metadataMimeTypeLength = byteBuf.getByte(fieldOffset); - fieldOffset += 1 + metadataMimeTypeLength; - - final int dataMimeTypeLength = byteBuf.getByte(fieldOffset); - fieldOffset += 1 + dataMimeTypeLength; - - return fieldOffset; - } - - private static int metadataMimetypeOffset(final ByteBuf byteBuf) { - return VARIABLE_DATA_OFFSET + resumeTokenTotalLength(byteBuf); - } - - private static int resumeTokenTotalLength(final ByteBuf byteBuf) { - if ((FrameHeaderFlyweight.flags(byteBuf) & FLAGS_RESUME_ENABLE) == 0) { - return 0; - } else { - return Short.BYTES + byteBuf.getShort(VARIABLE_DATA_OFFSET); - } - } - - private static int putMimeType( - final ByteBuf byteBuf, final int fieldOffset, final String mimeType) { - byte[] bytes = mimeType.getBytes(StandardCharsets.UTF_8); - - byteBuf.setByte(fieldOffset, (byte) bytes.length); - byteBuf.setBytes(fieldOffset + 1, bytes); - - return 1 + bytes.length; - } - - private static byte[] getMimeType(final ByteBuf byteBuf, final int fieldOffset) { - final int length = byteBuf.getByte(fieldOffset); - final byte[] bytes = new byte[length]; - - byteBuf.getBytes(fieldOffset + 1, bytes); - return bytes; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/VersionFlyweight.java b/rsocket-core/src/main/java/io/rsocket/frame/VersionCodec.java similarity index 87% rename from rsocket-core/src/main/java/io/rsocket/frame/VersionFlyweight.java rename to rsocket-core/src/main/java/io/rsocket/frame/VersionCodec.java index 23cff1638..35e4aa86a 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/VersionFlyweight.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/VersionCodec.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,7 +16,7 @@ package io.rsocket.frame; -public class VersionFlyweight { +public class VersionCodec { public static int encode(int major, int minor) { return (major << 16) | (minor & 0xFFFF); diff --git a/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java b/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java new file mode 100644 index 000000000..0d8063e0b --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java @@ -0,0 +1,69 @@ +package io.rsocket.frame.decoder; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.MetadataPushFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.util.DefaultPayload; +import java.nio.ByteBuffer; + +/** Default Frame decoder that copies the frames contents for easy of use. */ +class DefaultPayloadDecoder implements PayloadDecoder { + + @Override + public Payload apply(ByteBuf byteBuf) { + ByteBuf m; + ByteBuf d; + FrameType type = FrameHeaderCodec.frameType(byteBuf); + switch (type) { + case REQUEST_FNF: + d = RequestFireAndForgetFrameCodec.data(byteBuf); + m = RequestFireAndForgetFrameCodec.metadata(byteBuf); + break; + case REQUEST_RESPONSE: + d = RequestResponseFrameCodec.data(byteBuf); + m = RequestResponseFrameCodec.metadata(byteBuf); + break; + case REQUEST_STREAM: + d = RequestStreamFrameCodec.data(byteBuf); + m = RequestStreamFrameCodec.metadata(byteBuf); + break; + case REQUEST_CHANNEL: + d = RequestChannelFrameCodec.data(byteBuf); + m = RequestChannelFrameCodec.metadata(byteBuf); + break; + case NEXT: + case NEXT_COMPLETE: + d = PayloadFrameCodec.data(byteBuf); + m = PayloadFrameCodec.metadata(byteBuf); + break; + case METADATA_PUSH: + d = Unpooled.EMPTY_BUFFER; + m = MetadataPushFrameCodec.metadata(byteBuf); + break; + default: + throw new IllegalArgumentException("unsupported frame type: " + type); + } + + ByteBuffer data = ByteBuffer.allocate(d.readableBytes()); + data.put(d.nioBuffer()); + data.flip(); + + if (m != null) { + ByteBuffer metadata = ByteBuffer.allocate(m.readableBytes()); + metadata.put(m.nioBuffer()); + metadata.flip(); + + return DefaultPayload.create(data, metadata); + } + + return DefaultPayload.create(data); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/decoder/PayloadDecoder.java b/rsocket-core/src/main/java/io/rsocket/frame/decoder/PayloadDecoder.java new file mode 100644 index 000000000..197eca9b0 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/decoder/PayloadDecoder.java @@ -0,0 +1,10 @@ +package io.rsocket.frame.decoder; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Payload; +import java.util.function.Function; + +public interface PayloadDecoder extends Function { + PayloadDecoder DEFAULT = new DefaultPayloadDecoder(); + PayloadDecoder ZERO_COPY = new ZeroCopyPayloadDecoder(); +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/decoder/ZeroCopyPayloadDecoder.java b/rsocket-core/src/main/java/io/rsocket/frame/decoder/ZeroCopyPayloadDecoder.java new file mode 100644 index 000000000..3a0dc7bb5 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/decoder/ZeroCopyPayloadDecoder.java @@ -0,0 +1,58 @@ +package io.rsocket.frame.decoder; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.MetadataPushFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.util.ByteBufPayload; + +/** + * Frame decoder that decodes a frame to a payload without copying. The caller is responsible for + * for releasing the payload to free memory when they no long need it. + */ +public class ZeroCopyPayloadDecoder implements PayloadDecoder { + @Override + public Payload apply(ByteBuf byteBuf) { + ByteBuf m; + ByteBuf d; + FrameType type = FrameHeaderCodec.frameType(byteBuf); + switch (type) { + case REQUEST_FNF: + d = RequestFireAndForgetFrameCodec.data(byteBuf); + m = RequestFireAndForgetFrameCodec.metadata(byteBuf); + break; + case REQUEST_RESPONSE: + d = RequestResponseFrameCodec.data(byteBuf); + m = RequestResponseFrameCodec.metadata(byteBuf); + break; + case REQUEST_STREAM: + d = RequestStreamFrameCodec.data(byteBuf); + m = RequestStreamFrameCodec.metadata(byteBuf); + break; + case REQUEST_CHANNEL: + d = RequestChannelFrameCodec.data(byteBuf); + m = RequestChannelFrameCodec.metadata(byteBuf); + break; + case NEXT: + case NEXT_COMPLETE: + d = PayloadFrameCodec.data(byteBuf); + m = PayloadFrameCodec.metadata(byteBuf); + break; + case METADATA_PUSH: + d = Unpooled.EMPTY_BUFFER; + m = MetadataPushFrameCodec.metadata(byteBuf); + break; + default: + throw new IllegalArgumentException("unsupported frame type: " + type); + } + + return ByteBufPayload.create(d.retain(), m != null ? m.retain() : null); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/decoder/package-info.java b/rsocket-core/src/main/java/io/rsocket/frame/decoder/package-info.java new file mode 100644 index 000000000..82e8acaf3 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/decoder/package-info.java @@ -0,0 +1,24 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Support for encoding and decoding of RSocket frames to and from {@link io.rsocket.Payload + * Payload}. + */ +@NonNullApi +package io.rsocket.frame.decoder; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/frame/package-info.java b/rsocket-core/src/main/java/io/rsocket/frame/package-info.java index d177a5eb9..69f6d6860 100644 --- a/rsocket-core/src/main/java/io/rsocket/frame/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/frame/package-info.java @@ -1,18 +1,24 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -@javax.annotation.ParametersAreNonnullByDefault +/** + * Support for encoding and decoding of RSocket frames to and from {@link io.rsocket.Payload + * Payload}. + */ +@NonNullApi package io.rsocket.frame; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/internal/BaseDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/internal/BaseDuplexConnection.java new file mode 100644 index 000000000..0296b0a07 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/BaseDuplexConnection.java @@ -0,0 +1,56 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal; + +import io.netty.buffer.ByteBuf; +import io.rsocket.DuplexConnection; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +public abstract class BaseDuplexConnection implements DuplexConnection { + protected final Sinks.Empty onClose = Sinks.empty(); + protected final UnboundedProcessor sender = new UnboundedProcessor(onClose::tryEmitEmpty); + + public BaseDuplexConnection() {} + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + if (streamId == 0) { + sender.tryEmitPrioritized(frame); + } else { + sender.tryEmitNormal(frame); + } + } + + protected abstract void doOnClose(); + + @Override + public Mono onClose() { + return onClose.asMono(); + } + + @Override + public final void dispose() { + doOnClose(); + } + + @Override + @SuppressWarnings("ConstantConditions") + public final boolean isDisposed() { + return onClose.scan(Scannable.Attr.TERMINATED) || onClose.scan(Scannable.Attr.CANCELLED); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java b/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java index 861a8246e..8b1378917 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java @@ -1,175 +1 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.internal; - -import io.rsocket.DuplexConnection; -import io.rsocket.Frame; -import io.rsocket.FrameType; -import io.rsocket.plugins.DuplexConnectionInterceptor.Type; -import io.rsocket.plugins.PluginRegistry; -import org.reactivestreams.Publisher; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; - -/** - * {@link DuplexConnection#receive()} is a single stream on which the following type of frames - * arrive: - * - *

    - *
  • Frames for streams initiated by the initiator of the connection (client). - *
  • Frames for streams initiated by the acceptor of the connection (server). - *
- * - *

The only way to differentiate these two frames is determining whether the stream Id is odd or - * even. Even IDs are for the streams initiated by server and odds are for streams initiated by the - * client. - */ -public class ClientServerInputMultiplexer { - private static final Logger LOGGER = LoggerFactory.getLogger("io.rsocket.FrameLogger"); - - private final DuplexConnection streamZeroConnection; - private final DuplexConnection serverConnection; - private final DuplexConnection clientConnection; - private final DuplexConnection source; - - public ClientServerInputMultiplexer(DuplexConnection source, PluginRegistry plugins) { - this.source = source; - final MonoProcessor> streamZero = MonoProcessor.create(); - final MonoProcessor> server = MonoProcessor.create(); - final MonoProcessor> client = MonoProcessor.create(); - - source = plugins.applyConnection(Type.SOURCE, source); - streamZeroConnection = - plugins.applyConnection(Type.STREAM_ZERO, new InternalDuplexConnection(source, streamZero)); - serverConnection = - plugins.applyConnection(Type.SERVER, new InternalDuplexConnection(source, server)); - clientConnection = - plugins.applyConnection(Type.CLIENT, new InternalDuplexConnection(source, client)); - - source - .receive() - .groupBy( - frame -> { - int streamId = frame.getStreamId(); - final Type type; - if (streamId == 0) { - if (frame.getType() == FrameType.SETUP) { - type = Type.STREAM_ZERO; - } else { - type = Type.CLIENT; - } - } else if ((streamId & 0b1) == 0) { - type = Type.SERVER; - } else { - type = Type.CLIENT; - } - return type; - }) - .subscribe( - group -> { - switch (group.key()) { - case STREAM_ZERO: - streamZero.onNext(group); - break; - - case SERVER: - server.onNext(group); - break; - - case CLIENT: - client.onNext(group); - break; - } - }); - } - - public DuplexConnection asServerConnection() { - return serverConnection; - } - - public DuplexConnection asClientConnection() { - return clientConnection; - } - - public DuplexConnection asStreamZeroConnection() { - return streamZeroConnection; - } - - public Mono close() { - return source.close(); - } - - private static class InternalDuplexConnection implements DuplexConnection { - private final DuplexConnection source; - private final MonoProcessor> processor; - private final boolean debugEnabled; - - public InternalDuplexConnection(DuplexConnection source, MonoProcessor> processor) { - this.source = source; - this.processor = processor; - this.debugEnabled = LOGGER.isDebugEnabled(); - } - - @Override - public Mono send(Publisher frame) { - if (debugEnabled) { - frame = Flux.from(frame).doOnNext(f -> LOGGER.debug("sending -> " + f.toString())); - } - - return source.send(frame); - } - - @Override - public Mono sendOne(Frame frame) { - if (debugEnabled) { - LOGGER.debug("sending -> " + frame.toString()); - } - - return source.sendOne(frame); - } - - @Override - public Flux receive() { - return processor.flatMapMany( - f -> { - if (debugEnabled) { - return f.doOnNext(frame -> LOGGER.debug("receiving -> " + frame.toString())); - } else { - return f; - } - }); - } - - @Override - public Mono close() { - return source.close(); - } - - @Override - public Mono onClose() { - return source.onClose(); - } - - @Override - public double availability() { - return source.availability(); - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/LimitableRequestPublisher.java b/rsocket-core/src/main/java/io/rsocket/internal/LimitableRequestPublisher.java deleted file mode 100755 index 00015321c..000000000 --- a/rsocket-core/src/main/java/io/rsocket/internal/LimitableRequestPublisher.java +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.internal; - -import java.util.concurrent.atomic.AtomicBoolean; -import javax.annotation.Nullable; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import reactor.core.CoreSubscriber; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Operators; - -/** */ -public class LimitableRequestPublisher extends Flux implements Subscription { - private final Publisher source; - - private final AtomicBoolean canceled; - - private long internalRequested; - - private long externalRequested; - - private volatile boolean subscribed; - - private volatile @Nullable Subscription internalSubscription; - - private LimitableRequestPublisher(Publisher source) { - this.source = source; - this.canceled = new AtomicBoolean(); - } - - public static LimitableRequestPublisher wrap(Publisher source) { - return new LimitableRequestPublisher<>(source); - } - - @Override - public void subscribe(CoreSubscriber destination) { - synchronized (this) { - if (subscribed) { - throw new IllegalStateException("only one subscriber at a time"); - } - - subscribed = true; - } - - destination.onSubscribe(new InnerSubscription()); - source.subscribe(new InnerSubscriber(destination)); - } - - public void increaseRequestLimit(long n) { - synchronized (this) { - externalRequested = Operators.addCap(n, externalRequested); - } - - requestN(); - } - - @Override - public void request(long n) { - increaseRequestLimit(n); - } - - private void requestN() { - long r; - synchronized (this) { - if (internalSubscription == null) { - return; - } - - r = Math.min(internalRequested, externalRequested); - externalRequested -= r; - internalRequested -= r; - } - - if (r > 0) { - internalSubscription.request(r); - } - } - - public void cancel() { - if (canceled.compareAndSet(false, true) && internalSubscription != null) { - internalSubscription.cancel(); - internalSubscription = null; - subscribed = false; - } - } - - private class InnerSubscriber implements Subscriber { - Subscriber destination; - - private InnerSubscriber(Subscriber destination) { - this.destination = destination; - } - - @Override - public void onSubscribe(Subscription s) { - synchronized (LimitableRequestPublisher.this) { - LimitableRequestPublisher.this.internalSubscription = s; - - if (canceled.get()) { - s.cancel(); - subscribed = false; - LimitableRequestPublisher.this.internalSubscription = null; - } - } - - requestN(); - } - - @Override - public void onNext(T t) { - try { - destination.onNext(t); - } catch (Throwable e) { - onError(e); - } - } - - @Override - public void onError(Throwable t) { - destination.onError(t); - } - - @Override - public void onComplete() { - destination.onComplete(); - } - } - - private class InnerSubscription implements Subscription { - @Override - public void request(long n) { - synchronized (LimitableRequestPublisher.this) { - internalRequested = Operators.addCap(n, internalRequested); - } - - requestN(); - } - - @Override - public void cancel() { - LimitableRequestPublisher.this.cancel(); - } - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java b/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java new file mode 100644 index 000000000..c96a7aed2 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java @@ -0,0 +1,1167 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.internal; + +import io.netty.buffer.ByteBuf; +import io.rsocket.internal.jctools.queues.MpscUnboundedArrayQueue; +import java.util.Objects; +import java.util.Queue; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.stream.Stream; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Disposable; +import reactor.core.Exceptions; +import reactor.core.Fuseable; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; +import reactor.util.Logger; +import reactor.util.annotation.Nullable; +import reactor.util.concurrent.Queues; +import reactor.util.context.Context; + +/** + * A Processor implementation that takes a custom queue and allows only a single subscriber. + * + *

The implementation keeps the order of signals. + */ +public final class UnboundedProcessor extends Flux + implements Scannable, + Disposable, + CoreSubscriber, + Fuseable.QueueSubscription, + Fuseable { + + final Queue queue; + final Queue priorityQueue; + final Runnable onFinalizedHook; + @Nullable final Logger logger; + + boolean cancelled; + boolean done; + Throwable error; + CoreSubscriber actual; + + static final long FLAG_FINALIZED = + 0b1000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_DISPOSED = + 0b0100_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_TERMINATED = + 0b0010_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_CANCELLED = + 0b0001_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_HAS_VALUE = + 0b0000_1000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_HAS_REQUEST = + 0b0000_0100_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_SUBSCRIBER_READY = + 0b0000_0010_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_SUBSCRIBED_ONCE = + 0b0000_0001_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long MAX_WIP_VALUE = + 0b0000_0000_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111L; + + volatile long state; + + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(UnboundedProcessor.class, "state"); + + volatile int discardGuard; + + static final AtomicIntegerFieldUpdater DISCARD_GUARD = + AtomicIntegerFieldUpdater.newUpdater(UnboundedProcessor.class, "discardGuard"); + + volatile long requested; + + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(UnboundedProcessor.class, "requested"); + + ByteBuf last; + + boolean outputFused; + + public UnboundedProcessor() { + this(() -> {}); + } + + UnboundedProcessor(Logger logger) { + this(() -> {}, logger); + } + + public UnboundedProcessor(Runnable onFinalizedHook) { + this(onFinalizedHook, null); + } + + UnboundedProcessor(Runnable onFinalizedHook, @Nullable Logger logger) { + this.onFinalizedHook = onFinalizedHook; + this.queue = new MpscUnboundedArrayQueue<>(Queues.SMALL_BUFFER_SIZE); + this.priorityQueue = new MpscUnboundedArrayQueue<>(Queues.SMALL_BUFFER_SIZE); + this.logger = logger; + } + + @Override + public Stream inners() { + return hasDownstreams() ? Stream.of(Scannable.from(this.actual)) : Stream.empty(); + } + + @Override + public Object scanUnsafe(Attr key) { + if (Attr.ACTUAL == key) return isSubscriberReady(this.state) ? this.actual : null; + if (Attr.BUFFERED == key) return this.queue.size() + this.priorityQueue.size(); + if (Attr.PREFETCH == key) return Integer.MAX_VALUE; + if (Attr.CANCELLED == key) { + final long state = this.state; + return isCancelled(state) || isDisposed(state); + } + + return null; + } + + public boolean tryEmitPrioritized(ByteBuf t) { + if (this.done || this.cancelled) { + release(t); + return false; + } + + if (!this.priorityQueue.offer(t)) { + onError(Operators.onOperatorError(null, Exceptions.failWithOverflow(), t, currentContext())); + release(t); + return false; + } + + final long previousState = markValueAdded(this); + if (isFinalized(previousState)) { + this.clearSafely(); + return false; + } + + if (isSubscriberReady(previousState)) { + if (this.outputFused) { + // fast path for fusion + this.actual.onNext(null); + return true; + } + + if (isWorkInProgress(previousState)) { + return true; + } + + if (hasRequest(previousState)) { + drainRegular((previousState | FLAG_HAS_VALUE) + 1); + } + } + return true; + } + + public boolean tryEmitNormal(ByteBuf t) { + if (this.done || this.cancelled) { + release(t); + return false; + } + + if (!this.queue.offer(t)) { + onError(Operators.onOperatorError(null, Exceptions.failWithOverflow(), t, currentContext())); + release(t); + return false; + } + + final long previousState = markValueAdded(this); + if (isFinalized(previousState)) { + this.clearSafely(); + return false; + } + + if (isSubscriberReady(previousState)) { + if (this.outputFused) { + // fast path for fusion + this.actual.onNext(null); + return true; + } + + if (isWorkInProgress(previousState)) { + return true; + } + + if (hasRequest(previousState)) { + drainRegular((previousState | FLAG_HAS_VALUE) + 1); + } + } + + return true; + } + + public boolean tryEmitFinal(ByteBuf t) { + if (this.done || this.cancelled) { + release(t); + return false; + } + + this.last = t; + this.done = true; + + final long previousState = markValueAddedAndTerminated(this); + if (isFinalized(previousState)) { + this.clearSafely(); + return false; + } + + if (isSubscriberReady(previousState)) { + if (this.outputFused) { + // fast path for fusion + this.actual.onNext(null); + this.actual.onComplete(); + return true; + } + + if (isWorkInProgress(previousState)) { + return true; + } + + drainRegular((previousState | FLAG_TERMINATED | FLAG_HAS_VALUE) + 1); + } + + return true; + } + + @Deprecated + public void onNextPrioritized(ByteBuf t) { + tryEmitPrioritized(t); + } + + @Override + @Deprecated + public void onNext(ByteBuf t) { + tryEmitNormal(t); + } + + @Override + @Deprecated + public void onError(Throwable t) { + if (this.done || this.cancelled) { + Operators.onErrorDropped(t, currentContext()); + return; + } + + this.error = t; + this.done = true; + + final long previousState = markTerminatedOrFinalized(this); + if (isFinalized(previousState) + || isDisposed(previousState) + || isCancelled(previousState) + || isTerminated(previousState)) { + Operators.onErrorDropped(t, currentContext()); + return; + } + + if (isSubscriberReady(previousState)) { + if (this.outputFused) { + // fast path for fusion scenario + this.actual.onError(t); + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + if (!hasValue(previousState)) { + // fast path no-values scenario + this.actual.onError(t); + return; + } + + drainRegular((previousState | FLAG_TERMINATED) + 1); + } + } + + @Override + @Deprecated + public void onComplete() { + if (this.done || this.cancelled) { + return; + } + + this.done = true; + + final long previousState = markTerminatedOrFinalized(this); + if (isFinalized(previousState) + || isDisposed(previousState) + || isCancelled(previousState) + || isTerminated(previousState)) { + return; + } + + if (isSubscriberReady(previousState)) { + if (this.outputFused) { + // fast path for fusion scenario + this.actual.onComplete(); + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + if (!hasValue(previousState)) { + this.actual.onComplete(); + return; + } + + drainRegular((previousState | FLAG_TERMINATED) + 1); + } + } + + void drainRegular(long expectedState) { + final CoreSubscriber a = this.actual; + final Queue q = this.queue; + final Queue pq = this.priorityQueue; + + for (; ; ) { + + long r = this.requested; + long e = 0L; + + boolean empty = false; + boolean done; + while (r != e) { + // done has to be read before queue.poll to ensure there was no racing: + // Thread1: <#drain>: queue.poll(null) --------------------> this.done(true) + // Thread2: ------------------> <#onNext(V)> --> <#onComplete()> + done = this.done; + + ByteBuf t = pq.poll(); + empty = t == null; + + if (empty) { + t = q.poll(); + empty = t == null; + } + + if (checkTerminated(done, empty, true, a)) { + if (!empty) { + release(t); + } + return; + } + + if (empty) { + break; + } + + a.onNext(t); + + e++; + } + + if (r == e) { + // done has to be read before queue.isEmpty to ensure there was no racing: + // Thread1: <#drain>: queue.isEmpty(true) --------------------> this.done(true) + // Thread2: --------------------> <#onNext(V)> ---> <#onComplete()> + done = this.done; + empty = q.isEmpty() && pq.isEmpty(); + + if (checkTerminated(done, empty, false, a)) { + return; + } + } + + if (e != 0 && r != Long.MAX_VALUE) { + r = REQUESTED.addAndGet(this, -e); + } + + expectedState = markWorkDone(this, expectedState, r > 0, !empty); + if (isCancelled(expectedState)) { + clearAndFinalize(this); + return; + } + + if (isDisposed(expectedState)) { + clearAndFinalize(this); + a.onError(new CancellationException("Disposed")); + return; + } + + if (!isWorkInProgress(expectedState)) { + break; + } + } + } + + boolean checkTerminated( + boolean done, boolean empty, boolean hasDemand, CoreSubscriber a) { + final long state = this.state; + if (isCancelled(state)) { + clearAndFinalize(this); + return true; + } + + if (isDisposed(state)) { + clearAndFinalize(this); + a.onError(new CancellationException("Disposed")); + return true; + } + + if (done && empty) { + if (!isTerminated(state)) { + // proactively return if volatile field is not yet set to needed state + return false; + } + final ByteBuf last = this.last; + if (last != null) { + if (!hasDemand) { + return false; + } + this.last = null; + a.onNext(last); + } + clearAndFinalize(this); + Throwable e = this.error; + if (e != null) { + a.onError(e); + } else { + a.onComplete(); + } + return true; + } + + return false; + } + + @Override + public void onSubscribe(Subscription s) { + final long state = this.state; + if (isFinalized(state) || isTerminated(state) || isCancelled(state) || isDisposed(state)) { + s.cancel(); + } else { + s.request(Long.MAX_VALUE); + } + } + + @Override + public int getPrefetch() { + return Integer.MAX_VALUE; + } + + @Override + public Context currentContext() { + return isSubscriberReady(this.state) ? this.actual.currentContext() : Context.empty(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + Objects.requireNonNull(actual, "subscribe"); + long previousState = markSubscribedOnce(this); + if (isSubscribedOnce(previousState)) { + Operators.error( + actual, new IllegalStateException("UnboundedProcessor allows only a single Subscriber")); + return; + } + + if (isDisposed(previousState)) { + Operators.error(actual, new CancellationException("Disposed")); + return; + } + + actual.onSubscribe(this); + this.actual = actual; + + previousState = markSubscriberReady(this); + + if (isSubscriberReady(previousState)) { + return; + } + + if (this.outputFused) { + if (isCancelled(previousState)) { + return; + } + + if (isDisposed(previousState)) { + actual.onError(new CancellationException("Disposed")); + return; + } + + if (hasValue(previousState)) { + actual.onNext(null); + } + + if (isTerminated(previousState)) { + final Throwable e = this.error; + if (e != null) { + actual.onError(e); + } else { + actual.onComplete(); + } + } + return; + } + + if (isCancelled(previousState)) { + clearAndFinalize(this); + return; + } + + if (isDisposed(previousState)) { + clearAndFinalize(this); + actual.onError(new CancellationException("Disposed")); + return; + } + + if (!hasValue(previousState)) { + if (isTerminated(previousState)) { + clearAndFinalize(this); + final Throwable e = this.error; + if (e != null) { + actual.onError(e); + } else { + actual.onComplete(); + } + } + return; + } + + if (hasRequest(previousState)) { + drainRegular((previousState | FLAG_SUBSCRIBER_READY) + 1); + } + } + + @Override + public void request(long n) { + if (Operators.validate(n)) { + if (this.outputFused) { + final long state = this.state; + if (isSubscriberReady(state)) { + this.actual.onNext(null); + } + return; + } + + Operators.addCap(REQUESTED, this, n); + + final long previousState = markRequestAdded(this); + if (isWorkInProgress(previousState) + || isFinalized(previousState) + || isCancelled(previousState) + || isDisposed(previousState)) { + return; + } + + if (isSubscriberReady(previousState) && hasValue(previousState)) { + drainRegular((previousState | FLAG_HAS_REQUEST) + 1); + } + } + } + + @Override + public void cancel() { + this.cancelled = true; + + final long previousState = markCancelled(this); + if (isWorkInProgress(previousState) + || isFinalized(previousState) + || isCancelled(previousState) + || isDisposed(previousState)) { + return; + } + + if (!isSubscribedOnce(previousState) || !this.outputFused) { + clearAndFinalize(this); + } + } + + @Override + @Deprecated + public void dispose() { + this.cancelled = true; + + final long previousState = markDisposed(this); + if (isWorkInProgress(previousState) + || isFinalized(previousState) + || isCancelled(previousState) + || isDisposed(previousState)) { + return; + } + + if (!isSubscribedOnce(previousState)) { + clearAndFinalize(this); + return; + } + + if (!isSubscriberReady(previousState)) { + return; + } + + if (!this.outputFused) { + clearAndFinalize(this); + this.actual.onError(new CancellationException("Disposed")); + return; + } + + if (!isTerminated(previousState)) { + this.actual.onError(new CancellationException("Disposed")); + } + } + + @Override + @Nullable + public ByteBuf poll() { + ByteBuf t = this.priorityQueue.poll(); + if (t != null) { + return t; + } + + t = this.queue.poll(); + if (t != null) { + return t; + } + + t = this.last; + if (t != null) { + this.last = null; + return t; + } + + return null; + } + + @Override + public int size() { + return this.priorityQueue.size() + this.queue.size(); + } + + @Override + public boolean isEmpty() { + return this.priorityQueue.isEmpty() && this.queue.isEmpty(); + } + + /** + * Clears all elements from queues and set state to terminate. This method MUST be called only by + * the downstream subscriber which has enabled {@link Fuseable#ASYNC} fusion with the given {@link + * UnboundedProcessor} and is and indicator that the downstream is done with draining, it has + * observed any terminal signal (ON_COMPLETE or ON_ERROR or CANCEL) and will never be interacting + * with SingleConsumer queue anymore. + */ + @Override + public void clear() { + clearAndFinalize(this); + } + + void clearSafely() { + if (DISCARD_GUARD.getAndIncrement(this) != 0) { + return; + } + + int missed = 1; + for (; ; ) { + clearUnsafely(); + + missed = DISCARD_GUARD.addAndGet(this, -missed); + if (missed == 0) { + break; + } + } + } + + void clearUnsafely() { + final Queue queue = this.queue; + final Queue priorityQueue = this.priorityQueue; + + final ByteBuf last = this.last; + + if (last != null) { + release(last); + } + + ByteBuf byteBuf; + while ((byteBuf = queue.poll()) != null) { + release(byteBuf); + } + + while ((byteBuf = priorityQueue.poll()) != null) { + release(byteBuf); + } + } + + @Override + public int requestFusion(int requestedMode) { + if ((requestedMode & Fuseable.ASYNC) != 0) { + this.outputFused = true; + return Fuseable.ASYNC; + } + return Fuseable.NONE; + } + + @Override + public boolean isDisposed() { + return isFinalized(this.state); + } + + boolean hasDownstreams() { + final long state = this.state; + return !isTerminated(state) && isSubscriberReady(state); + } + + static void release(ByteBuf byteBuf) { + if (byteBuf.refCnt() > 0) { + try { + byteBuf.release(); + } catch (Throwable ex) { + // no ops + } + } + } + + /** + * Sets {@link #FLAG_SUBSCRIBED_ONCE} flag if it was not set before and if flags {@link + * #FLAG_FINALIZED}, {@link #FLAG_CANCELLED} or {@link #FLAG_DISPOSED} are unset + * + * @return {@code true} if {@link #FLAG_SUBSCRIBED_ONCE} was successfully set + */ + static long markSubscribedOnce(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isSubscribedOnce(state)) { + return state; + } + + final long nextState = state | FLAG_SUBSCRIBED_ONCE; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mso", state, nextState); + return state; + } + } + } + + /** + * Sets {@link #FLAG_SUBSCRIBER_READY} flag if flags {@link #FLAG_FINALIZED}, {@link + * #FLAG_CANCELLED} or {@link #FLAG_DISPOSED} are unset + * + * @return previous state + */ + static long markSubscriberReady(UnboundedProcessor instance) { + for (; ; ) { + long state = instance.state; + + if (isFinalized(state) + || isCancelled(state) + || isDisposed(state) + || isSubscriberReady(state)) { + return state; + } + + long nextState = state; + if (!instance.outputFused) { + if ((!hasValue(state) && isTerminated(state)) || (hasRequest(state) && hasValue(state))) { + nextState = addWork(state); + } + } + + nextState = nextState | FLAG_SUBSCRIBER_READY; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " msr", state, nextState); + return state; + } + } + } + + /** + * Sets {@link #FLAG_HAS_REQUEST} flag if it was not set before and if flags {@link + * #FLAG_FINALIZED}, {@link #FLAG_CANCELLED}, {@link #FLAG_DISPOSED} are unset. Also, this method + * increments number of work in progress (WIP) + * + * @return previous state + */ + static long markRequestAdded(UnboundedProcessor instance) { + for (; ; ) { + long state = instance.state; + + if (isFinalized(state) || isCancelled(state) || isDisposed(state)) { + return state; + } + + long nextState = state; + if (isWorkInProgress(state) || (isSubscriberReady(state) && hasValue(state))) { + nextState = addWork(state); + } + + nextState = nextState | FLAG_HAS_REQUEST; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mra", state, nextState); + return state; + } + } + } + + /** + * Sets {@link #FLAG_HAS_VALUE} flag if it was not set before and if flags {@link + * #FLAG_FINALIZED}, {@link #FLAG_CANCELLED}, {@link #FLAG_DISPOSED} are unset. Also, this method + * increments number of work in progress (WIP) if {@link #FLAG_HAS_REQUEST} is set + * + * @return previous state + */ + static long markValueAdded(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state)) { + return state; + } + + long nextState = state; + if (isWorkInProgress(state)) { + nextState = addWork(state); + } else if (isSubscriberReady(state)) { + if (instance.outputFused) { + // fast path for fusion scenario + return state; + } + + if (hasRequest(state)) { + nextState = addWork(state); + } + } + + nextState = nextState | FLAG_HAS_VALUE; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mva", state, nextState); + return state; + } + } + } + + /** + * Sets {@link #FLAG_HAS_VALUE} flag if it was not set before and if flags {@link + * #FLAG_FINALIZED}, {@link #FLAG_CANCELLED}, {@link #FLAG_DISPOSED} are unset. Also, this method + * increments number of work in progress (WIP) if {@link #FLAG_HAS_REQUEST} is set + * + * @return previous state + */ + static long markValueAddedAndTerminated(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state)) { + return state; + } + + long nextState = state; + if (isWorkInProgress(state)) { + nextState = addWork(state); + } else if (isSubscriberReady(state) && !instance.outputFused) { + nextState = addWork(state); + } + + nextState = nextState | FLAG_HAS_VALUE | FLAG_TERMINATED; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, "mva&t", state, nextState); + return state; + } + } + } + + /** + * Sets {@link #FLAG_TERMINATED} flag if it was not set before and if flags {@link + * #FLAG_FINALIZED}, {@link #FLAG_CANCELLED}, {@link #FLAG_DISPOSED} are unset. Also, this method + * increments number of work in progress (WIP) + * + * @return previous state + */ + static long markTerminatedOrFinalized(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state) || isTerminated(state) || isCancelled(state) || isDisposed(state)) { + return state; + } + + long nextState = state; + if (isWorkInProgress(state)) { + nextState = addWork(state); + } else if (isSubscriberReady(state) && !instance.outputFused) { + if (!hasValue(state)) { + // fast path for no values and no work in progress + nextState = FLAG_FINALIZED; + } else { + nextState = addWork(state); + } + } + + nextState = nextState | FLAG_TERMINATED; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mt|f", state, nextState); + if (isFinalized(nextState)) { + instance.onFinalizedHook.run(); + } + return state; + } + } + } + + /** + * Sets {@link #FLAG_CANCELLED} flag if it was not set before and if flag {@link #FLAG_FINALIZED} + * is unset. Also, this method increments number of work in progress (WIP) + * + * @return previous state + */ + static long markCancelled(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state) || isCancelled(state)) { + return state; + } + + final long nextState = addWork(state) | FLAG_CANCELLED; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mc", state, nextState); + return state; + } + } + } + + /** + * Sets {@link #FLAG_DISPOSED} flag if it was not set before and if flags {@link #FLAG_FINALIZED}, + * {@link #FLAG_CANCELLED} are unset. Also, this method increments number of work in progress + * (WIP) + * + * @return previous state + */ + static long markDisposed(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state) || isCancelled(state) || isDisposed(state)) { + return state; + } + + final long nextState = addWork(state) | FLAG_DISPOSED; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " md", state, nextState); + return state; + } + } + } + + static long addWork(long state) { + return (state & MAX_WIP_VALUE) == MAX_WIP_VALUE ? state : state + 1; + } + + /** + * Decrements the amount of work in progress by the given amount on the given state. Fails if flag + * is {@link #FLAG_FINALIZED} is set or if fusion disabled and flags {@link #FLAG_CANCELLED} or + * {@link #FLAG_DISPOSED} are set. + * + *

Note, if fusion is enabled, the decrement should work if flags {@link #FLAG_CANCELLED} or + * {@link #FLAG_DISPOSED} are set, since, while the operator was not terminate by the downstream, + * we still have to propagate notifications that new elements are enqueued + * + * @return state after changing WIP or current state if update failed + */ + static long markWorkDone( + UnboundedProcessor instance, long expectedState, boolean hasRequest, boolean hasValue) { + for (; ; ) { + final long state = instance.state; + + if (state != expectedState) { + return state; + } + + if (isFinalized(state) || isCancelled(state) || isDisposed(state)) { + return state; + } + + final long nextState = + (state - (expectedState & MAX_WIP_VALUE)) + ^ (hasRequest ? 0 : FLAG_HAS_REQUEST) + ^ (hasValue ? 0 : FLAG_HAS_VALUE); + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mwd", state, nextState); + return nextState; + } + } + } + + /** + * Set flag {@link #FLAG_FINALIZED} and {@link #release(ByteBuf)} all the elements from {@link + * #queue} and {@link #priorityQueue}. + * + *

This method may be called concurrently only if the given {@link UnboundedProcessor} has no + * output fusion ({@link #outputFused} {@code == true}). Otherwise this method MUST only by the + * downstream calling method {@link #clear()} + */ + static void clearAndFinalize(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state)) { + instance.clearSafely(); + return; + } + + if (!isSubscriberReady(state) || !instance.outputFused) { + instance.clearSafely(); + } else { + instance.clearUnsafely(); + } + + long nextState = (state & ~MAX_WIP_VALUE & ~FLAG_HAS_VALUE) | FLAG_FINALIZED; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " c&f", state, nextState); + instance.onFinalizedHook.run(); + break; + } + } + } + + static boolean hasValue(long state) { + return (state & FLAG_HAS_VALUE) == FLAG_HAS_VALUE; + } + + static boolean hasRequest(long state) { + return (state & FLAG_HAS_REQUEST) == FLAG_HAS_REQUEST; + } + + static boolean isCancelled(long state) { + return (state & FLAG_CANCELLED) == FLAG_CANCELLED; + } + + static boolean isDisposed(long state) { + return (state & FLAG_DISPOSED) == FLAG_DISPOSED; + } + + static boolean isWorkInProgress(long state) { + return (state & MAX_WIP_VALUE) != 0; + } + + static boolean isTerminated(long state) { + return (state & FLAG_TERMINATED) == FLAG_TERMINATED; + } + + static boolean isFinalized(long state) { + return (state & FLAG_FINALIZED) == FLAG_FINALIZED; + } + + static boolean isSubscriberReady(long state) { + return (state & FLAG_SUBSCRIBER_READY) == FLAG_SUBSCRIBER_READY; + } + + static boolean isSubscribedOnce(long state) { + return (state & FLAG_SUBSCRIBED_ONCE) == FLAG_SUBSCRIBED_ONCE; + } + + static void log( + UnboundedProcessor instance, String action, long initialState, long committedState) { + log(instance, action, initialState, committedState, false); + } + + static void log( + UnboundedProcessor instance, + String action, + long initialState, + long committedState, + boolean logStackTrace) { + Logger logger = instance.logger; + if (logger == null || !logger.isTraceEnabled()) { + return; + } + + if (logStackTrace) { + logger.trace( + String.format( + "[%s][%s][%s][%s-%s]", + instance, + action, + action, + Thread.currentThread().getId(), + formatState(initialState, 64), + formatState(committedState, 64)), + new RuntimeException()); + } else { + logger.trace( + String.format( + "[%s][%s][%s][%s-%s]", + instance, + action, + Thread.currentThread().getId(), + formatState(initialState, 64), + formatState(committedState, 64))); + } + } + + static void log( + UnboundedProcessor instance, String action, int initialState, int committedState) { + log(instance, action, initialState, committedState, false); + } + + static void log( + UnboundedProcessor instance, + String action, + int initialState, + int committedState, + boolean logStackTrace) { + Logger logger = instance.logger; + if (logger == null || !logger.isTraceEnabled()) { + return; + } + + if (logStackTrace) { + logger.trace( + String.format( + "[%s][%s][%s][%s-%s]", + instance, + action, + action, + Thread.currentThread().getId(), + formatState(initialState, 32), + formatState(committedState, 32)), + new RuntimeException()); + } else { + logger.trace( + String.format( + "[%s][%s][%s][%s-%s]", + instance, + action, + Thread.currentThread().getId(), + formatState(initialState, 32), + formatState(committedState, 32))); + } + } + + static String formatState(long state, int size) { + final String defaultFormat = Long.toBinaryString(state); + final StringBuilder formatted = new StringBuilder(); + final int toPrepend = size - defaultFormat.length(); + for (int i = 0; i < size; i++) { + if (i != 0 && i % 4 == 0) { + formatted.append("_"); + } + if (i < toPrepend) { + formatted.append("0"); + } else { + formatted.append(defaultFormat.charAt(i - toPrepend)); + } + } + + formatted.insert(0, "0b"); + return formatted.toString(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseLinkedQueue.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseLinkedQueue.java new file mode 100644 index 000000000..a99ef8a49 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseLinkedQueue.java @@ -0,0 +1,302 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import static io.rsocket.internal.jctools.queues.UnsafeAccess.UNSAFE; +import static io.rsocket.internal.jctools.queues.UnsafeAccess.fieldOffset; + +import java.util.AbstractQueue; +import java.util.Iterator; + +abstract class BaseLinkedQueuePad0 extends AbstractQueue implements MessagePassingQueue { + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + // byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + // * drop 8b as object header acts as padding and is >= 8b * +} + +// $gen:ordered-fields +abstract class BaseLinkedQueueProducerNodeRef extends BaseLinkedQueuePad0 { + static final long P_NODE_OFFSET = + fieldOffset(BaseLinkedQueueProducerNodeRef.class, "producerNode"); + + private volatile LinkedQueueNode producerNode; + + final void spProducerNode(LinkedQueueNode newValue) { + UNSAFE.putObject(this, P_NODE_OFFSET, newValue); + } + + final void soProducerNode(LinkedQueueNode newValue) { + UNSAFE.putOrderedObject(this, P_NODE_OFFSET, newValue); + } + + final LinkedQueueNode lvProducerNode() { + return producerNode; + } + + final boolean casProducerNode(LinkedQueueNode expect, LinkedQueueNode newValue) { + return UNSAFE.compareAndSwapObject(this, P_NODE_OFFSET, expect, newValue); + } + + final LinkedQueueNode lpProducerNode() { + return producerNode; + } +} + +abstract class BaseLinkedQueuePad1 extends BaseLinkedQueueProducerNodeRef { + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b +} + +// $gen:ordered-fields +abstract class BaseLinkedQueueConsumerNodeRef extends BaseLinkedQueuePad1 { + private static final long C_NODE_OFFSET = + fieldOffset(BaseLinkedQueueConsumerNodeRef.class, "consumerNode"); + + private LinkedQueueNode consumerNode; + + final void spConsumerNode(LinkedQueueNode newValue) { + consumerNode = newValue; + } + + @SuppressWarnings("unchecked") + final LinkedQueueNode lvConsumerNode() { + return (LinkedQueueNode) UNSAFE.getObjectVolatile(this, C_NODE_OFFSET); + } + + final LinkedQueueNode lpConsumerNode() { + return consumerNode; + } +} + +abstract class BaseLinkedQueuePad2 extends BaseLinkedQueueConsumerNodeRef { + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b +} + +/** + * A base data structure for concurrent linked queues. For convenience also pulled in common single + * consumer methods since at this time there's no plan to implement MC. + */ +abstract class BaseLinkedQueue extends BaseLinkedQueuePad2 { + + @Override + public final Iterator iterator() { + throw new UnsupportedOperationException(); + } + + @Override + public String toString() { + return this.getClass().getName(); + } + + protected final LinkedQueueNode newNode() { + return new LinkedQueueNode(); + } + + protected final LinkedQueueNode newNode(E e) { + return new LinkedQueueNode(e); + } + + /** + * {@inheritDoc}
+ * + *

IMPLEMENTATION NOTES:
+ * This is an O(n) operation as we run through all the nodes and count them.
+ * The accuracy of the value returned by this method is subject to races with producer/consumer + * threads. In particular when racing with the consumer thread this method may under estimate the + * size.
+ * + * @see java.util.Queue#size() + */ + @Override + public final int size() { + // Read consumer first, this is important because if the producer is node is 'older' than the + // consumer + // the consumer may overtake it (consume past it) invalidating the 'snapshot' notion of size. + LinkedQueueNode chaserNode = lvConsumerNode(); + LinkedQueueNode producerNode = lvProducerNode(); + int size = 0; + // must chase the nodes all the way to the producer node, but there's no need to count beyond + // expected head. + while (chaserNode != producerNode + && // don't go passed producer node + chaserNode != null + && // stop at last node + size < Integer.MAX_VALUE) // stop at max int + { + LinkedQueueNode next; + next = chaserNode.lvNext(); + // check if this node has been consumed, if so return what we have + if (next == chaserNode) { + return size; + } + chaserNode = next; + size++; + } + return size; + } + + /** + * {@inheritDoc}
+ * + *

IMPLEMENTATION NOTES:
+ * Queue is empty when producerNode is the same as consumerNode. An alternative implementation + * would be to observe the producerNode.value is null, which also means an empty queue because + * only the consumerNode.value is allowed to be null. + * + * @see MessagePassingQueue#isEmpty() + */ + @Override + public boolean isEmpty() { + LinkedQueueNode consumerNode = lvConsumerNode(); + LinkedQueueNode producerNode = lvProducerNode(); + return consumerNode == producerNode; + } + + protected E getSingleConsumerNodeValue( + LinkedQueueNode currConsumerNode, LinkedQueueNode nextNode) { + // we have to null out the value because we are going to hang on to the node + final E nextValue = nextNode.getAndNullValue(); + + // Fix up the next ref of currConsumerNode to prevent promoted nodes from keeping new ones + // alive. + // We use a reference to self instead of null because null is already a meaningful value (the + // next of + // producer node is null). + currConsumerNode.soNext(currConsumerNode); + spConsumerNode(nextNode); + // currConsumerNode is now no longer referenced and can be collected + return nextValue; + } + + @Override + public E relaxedPoll() { + final LinkedQueueNode currConsumerNode = lpConsumerNode(); + final LinkedQueueNode nextNode = currConsumerNode.lvNext(); + if (nextNode != null) { + return getSingleConsumerNodeValue(currConsumerNode, nextNode); + } + return null; + } + + @Override + public E relaxedPeek() { + final LinkedQueueNode nextNode = lpConsumerNode().lvNext(); + if (nextNode != null) { + return nextNode.lpValue(); + } + return null; + } + + @Override + public boolean relaxedOffer(E e) { + return offer(e); + } + + @Override + public int drain(Consumer c) { + long result = 0; // use long to force safepoint into loop below + int drained; + do { + drained = drain(c, 4096); + result += drained; + } while (drained == 4096 && result <= Integer.MAX_VALUE - 4096); + return (int) result; + } + + @Override + public int drain(Consumer c, int limit) { + LinkedQueueNode chaserNode = this.lpConsumerNode(); + for (int i = 0; i < limit; i++) { + final LinkedQueueNode nextNode = chaserNode.lvNext(); + + if (nextNode == null) { + return i; + } + // we have to null out the value because we are going to hang on to the node + final E nextValue = getSingleConsumerNodeValue(chaserNode, nextNode); + chaserNode = nextNode; + c.accept(nextValue); + } + return limit; + } + + @Override + public void drain(Consumer c, WaitStrategy wait, ExitCondition exit) { + LinkedQueueNode chaserNode = this.lpConsumerNode(); + int idleCounter = 0; + while (exit.keepRunning()) { + for (int i = 0; i < 4096; i++) { + final LinkedQueueNode nextNode = chaserNode.lvNext(); + if (nextNode == null) { + idleCounter = wait.idle(idleCounter); + continue; + } + + idleCounter = 0; + // we have to null out the value because we are going to hang on to the node + final E nextValue = getSingleConsumerNodeValue(chaserNode, nextNode); + chaserNode = nextNode; + c.accept(nextValue); + } + } + } + + @Override + public int capacity() { + return UNBOUNDED_CAPACITY; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseMpscLinkedArrayQueue.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseMpscLinkedArrayQueue.java new file mode 100644 index 000000000..cfad5ef71 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseMpscLinkedArrayQueue.java @@ -0,0 +1,705 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import static io.rsocket.internal.jctools.queues.LinkedArrayQueueUtil.length; +import static io.rsocket.internal.jctools.queues.LinkedArrayQueueUtil.modifiedCalcCircularRefElementOffset; +import static io.rsocket.internal.jctools.queues.UnsafeAccess.UNSAFE; +import static io.rsocket.internal.jctools.queues.UnsafeAccess.fieldOffset; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.allocateRefArray; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.calcCircularRefElementOffset; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.calcRefElementOffset; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.lvRefElement; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.soRefElement; + +import io.rsocket.internal.jctools.queues.IndexedQueueSizeUtil.IndexedQueue; +import java.util.AbstractQueue; +import java.util.Iterator; +import java.util.NoSuchElementException; + +abstract class BaseMpscLinkedArrayQueuePad1 extends AbstractQueue implements IndexedQueue { + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b +} + +// $gen:ordered-fields +abstract class BaseMpscLinkedArrayQueueProducerFields extends BaseMpscLinkedArrayQueuePad1 { + private static final long P_INDEX_OFFSET = + fieldOffset(BaseMpscLinkedArrayQueueProducerFields.class, "producerIndex"); + + private volatile long producerIndex; + + @Override + public final long lvProducerIndex() { + return producerIndex; + } + + final void soProducerIndex(long newValue) { + UNSAFE.putOrderedLong(this, P_INDEX_OFFSET, newValue); + } + + final boolean casProducerIndex(long expect, long newValue) { + return UNSAFE.compareAndSwapLong(this, P_INDEX_OFFSET, expect, newValue); + } +} + +abstract class BaseMpscLinkedArrayQueuePad2 extends BaseMpscLinkedArrayQueueProducerFields { + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b +} + +// $gen:ordered-fields +abstract class BaseMpscLinkedArrayQueueConsumerFields extends BaseMpscLinkedArrayQueuePad2 { + private static final long C_INDEX_OFFSET = + fieldOffset(BaseMpscLinkedArrayQueueConsumerFields.class, "consumerIndex"); + + private volatile long consumerIndex; + protected long consumerMask; + protected E[] consumerBuffer; + + @Override + public final long lvConsumerIndex() { + return consumerIndex; + } + + final long lpConsumerIndex() { + return UNSAFE.getLong(this, C_INDEX_OFFSET); + } + + final void soConsumerIndex(long newValue) { + UNSAFE.putOrderedLong(this, C_INDEX_OFFSET, newValue); + } +} + +abstract class BaseMpscLinkedArrayQueuePad3 extends BaseMpscLinkedArrayQueueConsumerFields { + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b +} + +// $gen:ordered-fields +abstract class BaseMpscLinkedArrayQueueColdProducerFields + extends BaseMpscLinkedArrayQueuePad3 { + private static final long P_LIMIT_OFFSET = + fieldOffset(BaseMpscLinkedArrayQueueColdProducerFields.class, "producerLimit"); + + private volatile long producerLimit; + protected long producerMask; + protected E[] producerBuffer; + + final long lvProducerLimit() { + return producerLimit; + } + + final boolean casProducerLimit(long expect, long newValue) { + return UNSAFE.compareAndSwapLong(this, P_LIMIT_OFFSET, expect, newValue); + } + + final void soProducerLimit(long newValue) { + UNSAFE.putOrderedLong(this, P_LIMIT_OFFSET, newValue); + } +} + +/** + * An MPSC array queue which starts at initialCapacity and grows to maxCapacity in + * linked chunks of the initial size. The queue grows only when the current buffer is full and + * elements are not copied on resize, instead a link to the new buffer is stored in the old buffer + * for the consumer to follow. + */ +abstract class BaseMpscLinkedArrayQueue extends BaseMpscLinkedArrayQueueColdProducerFields + implements MessagePassingQueue, QueueProgressIndicators { + // No post padding here, subclasses must add + private static final Object JUMP = new Object(); + private static final Object BUFFER_CONSUMED = new Object(); + private static final int CONTINUE_TO_P_INDEX_CAS = 0; + private static final int RETRY = 1; + private static final int QUEUE_FULL = 2; + private static final int QUEUE_RESIZE = 3; + + /** + * @param initialCapacity the queue initial capacity. If chunk size is fixed this will be the + * chunk size. Must be 2 or more. + */ + public BaseMpscLinkedArrayQueue(final int initialCapacity) { + RangeUtil.checkGreaterThanOrEqual(initialCapacity, 2, "initialCapacity"); + + int p2capacity = Pow2.roundToPowerOfTwo(initialCapacity); + // leave lower bit of mask clear + long mask = (p2capacity - 1) << 1; + // need extra element to point at next array + E[] buffer = allocateRefArray(p2capacity + 1); + producerBuffer = buffer; + producerMask = mask; + consumerBuffer = buffer; + consumerMask = mask; + soProducerLimit(mask); // we know it's all empty to start with + } + + @Override + public int size() { + // NOTE: because indices are on even numbers we cannot use the size util. + + /* + * It is possible for a thread to be interrupted or reschedule between the read of the producer and + * consumer indices, therefore protection is required to ensure size is within valid range. In the + * event of concurrent polls/offers to this method the size is OVER estimated as we read consumer + * index BEFORE the producer index. + */ + long after = lvConsumerIndex(); + long size; + while (true) { + final long before = after; + final long currentProducerIndex = lvProducerIndex(); + after = lvConsumerIndex(); + if (before == after) { + size = ((currentProducerIndex - after) >> 1); + break; + } + } + // Long overflow is impossible, so size is always positive. Integer overflow is possible for the + // unbounded + // indexed queues. + if (size > Integer.MAX_VALUE) { + return Integer.MAX_VALUE; + } else { + return (int) size; + } + } + + @Override + public boolean isEmpty() { + // Order matters! + // Loading consumer before producer allows for producer increments after consumer index is read. + // This ensures this method is conservative in it's estimate. Note that as this is an MPMC there + // is + // nothing we can do to make this an exact method. + return (this.lvConsumerIndex() == this.lvProducerIndex()); + } + + @Override + public String toString() { + return this.getClass().getName(); + } + + @Override + public boolean offer(final E e) { + if (null == e) { + throw new NullPointerException(); + } + + long mask; + E[] buffer; + long pIndex; + + while (true) { + long producerLimit = lvProducerLimit(); + pIndex = lvProducerIndex(); + // lower bit is indicative of resize, if we see it we spin until it's cleared + if ((pIndex & 1) == 1) { + continue; + } + // pIndex is even (lower bit is 0) -> actual index is (pIndex >> 1) + + // mask/buffer may get changed by resizing -> only use for array access after successful CAS. + mask = this.producerMask; + buffer = this.producerBuffer; + // a successful CAS ties the ordering, lv(pIndex) - [mask/buffer] -> cas(pIndex) + + // assumption behind this optimization is that queue is almost always empty or near empty + if (producerLimit <= pIndex) { + int result = offerSlowPath(mask, pIndex, producerLimit); + switch (result) { + case CONTINUE_TO_P_INDEX_CAS: + break; + case RETRY: + continue; + case QUEUE_FULL: + return false; + case QUEUE_RESIZE: + resize(mask, buffer, pIndex, e, null); + return true; + } + } + + if (casProducerIndex(pIndex, pIndex + 2)) { + break; + } + } + // INDEX visible before ELEMENT + final long offset = modifiedCalcCircularRefElementOffset(pIndex, mask); + soRefElement(buffer, offset, e); // release element e + return true; + } + + /** + * {@inheritDoc} + * + *

This implementation is correct for single consumer thread use only. + */ + @SuppressWarnings("unchecked") + @Override + public E poll() { + final E[] buffer = consumerBuffer; + final long index = lpConsumerIndex(); + final long mask = consumerMask; + + final long offset = modifiedCalcCircularRefElementOffset(index, mask); + Object e = lvRefElement(buffer, offset); + if (e == null) { + if (index != lvProducerIndex()) { + // poll() == null iff queue is empty, null element is not strong enough indicator, so we + // must + // check the producer index. If the queue is indeed not empty we spin until element is + // visible. + do { + e = lvRefElement(buffer, offset); + } while (e == null); + } else { + return null; + } + } + + if (e == JUMP) { + final E[] nextBuffer = nextBuffer(buffer, mask); + return newBufferPoll(nextBuffer, index); + } + + soRefElement(buffer, offset, null); // release element null + soConsumerIndex(index + 2); // release cIndex + return (E) e; + } + + /** + * {@inheritDoc} + * + *

This implementation is correct for single consumer thread use only. + */ + @SuppressWarnings("unchecked") + @Override + public E peek() { + final E[] buffer = consumerBuffer; + final long index = lpConsumerIndex(); + final long mask = consumerMask; + + final long offset = modifiedCalcCircularRefElementOffset(index, mask); + Object e = lvRefElement(buffer, offset); + if (e == null && index != lvProducerIndex()) { + // peek() == null iff queue is empty, null element is not strong enough indicator, so we must + // check the producer index. If the queue is indeed not empty we spin until element is + // visible. + do { + e = lvRefElement(buffer, offset); + } while (e == null); + } + if (e == JUMP) { + return newBufferPeek(nextBuffer(buffer, mask), index); + } + return (E) e; + } + + /** We do not inline resize into this method because we do not resize on fill. */ + private int offerSlowPath(long mask, long pIndex, long producerLimit) { + final long cIndex = lvConsumerIndex(); + long bufferCapacity = getCurrentBufferCapacity(mask); + + if (cIndex + bufferCapacity > pIndex) { + if (!casProducerLimit(producerLimit, cIndex + bufferCapacity)) { + // retry from top + return RETRY; + } else { + // continue to pIndex CAS + return CONTINUE_TO_P_INDEX_CAS; + } + } + // full and cannot grow + else if (availableInQueue(pIndex, cIndex) <= 0) { + // offer should return false; + return QUEUE_FULL; + } + // grab index for resize -> set lower bit + else if (casProducerIndex(pIndex, pIndex + 1)) { + // trigger a resize + return QUEUE_RESIZE; + } else { + // failed resize attempt, retry from top + return RETRY; + } + } + + /** @return available elements in queue * 2 */ + protected abstract long availableInQueue(long pIndex, long cIndex); + + @SuppressWarnings("unchecked") + private E[] nextBuffer(final E[] buffer, final long mask) { + final long offset = nextArrayOffset(mask); + final E[] nextBuffer = (E[]) lvRefElement(buffer, offset); + consumerBuffer = nextBuffer; + consumerMask = (length(nextBuffer) - 2) << 1; + soRefElement(buffer, offset, BUFFER_CONSUMED); + return nextBuffer; + } + + private static long nextArrayOffset(long mask) { + return modifiedCalcCircularRefElementOffset(mask + 2, Long.MAX_VALUE); + } + + private E newBufferPoll(E[] nextBuffer, long index) { + final long offset = modifiedCalcCircularRefElementOffset(index, consumerMask); + final E n = lvRefElement(nextBuffer, offset); + if (n == null) { + throw new IllegalStateException("new buffer must have at least one element"); + } + soRefElement(nextBuffer, offset, null); + soConsumerIndex(index + 2); + return n; + } + + private E newBufferPeek(E[] nextBuffer, long index) { + final long offset = modifiedCalcCircularRefElementOffset(index, consumerMask); + final E n = lvRefElement(nextBuffer, offset); + if (null == n) { + throw new IllegalStateException("new buffer must have at least one element"); + } + return n; + } + + @Override + public long currentProducerIndex() { + return lvProducerIndex() / 2; + } + + @Override + public long currentConsumerIndex() { + return lvConsumerIndex() / 2; + } + + @Override + public abstract int capacity(); + + @Override + public boolean relaxedOffer(E e) { + return offer(e); + } + + @SuppressWarnings("unchecked") + @Override + public E relaxedPoll() { + final E[] buffer = consumerBuffer; + final long index = lpConsumerIndex(); + final long mask = consumerMask; + + final long offset = modifiedCalcCircularRefElementOffset(index, mask); + Object e = lvRefElement(buffer, offset); + if (e == null) { + return null; + } + if (e == JUMP) { + final E[] nextBuffer = nextBuffer(buffer, mask); + return newBufferPoll(nextBuffer, index); + } + soRefElement(buffer, offset, null); + soConsumerIndex(index + 2); + return (E) e; + } + + @SuppressWarnings("unchecked") + @Override + public E relaxedPeek() { + final E[] buffer = consumerBuffer; + final long index = lpConsumerIndex(); + final long mask = consumerMask; + + final long offset = modifiedCalcCircularRefElementOffset(index, mask); + Object e = lvRefElement(buffer, offset); + if (e == JUMP) { + return newBufferPeek(nextBuffer(buffer, mask), index); + } + return (E) e; + } + + @Override + public int fill(Supplier s) { + long result = + 0; // result is a long because we want to have a safepoint check at regular intervals + final int capacity = capacity(); + do { + final int filled = fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH); + if (filled == 0) { + return (int) result; + } + result += filled; + } while (result <= capacity); + return (int) result; + } + + @Override + public int fill(Supplier s, int limit) { + if (null == s) throw new IllegalArgumentException("supplier is null"); + if (limit < 0) throw new IllegalArgumentException("limit is negative:" + limit); + if (limit == 0) return 0; + + long mask; + E[] buffer; + long pIndex; + int claimedSlots; + while (true) { + long producerLimit = lvProducerLimit(); + pIndex = lvProducerIndex(); + // lower bit is indicative of resize, if we see it we spin until it's cleared + if ((pIndex & 1) == 1) { + continue; + } + // pIndex is even (lower bit is 0) -> actual index is (pIndex >> 1) + + // NOTE: mask/buffer may get changed by resizing -> only use for array access after successful + // CAS. + // Only by virtue offloading them between the lvProducerIndex and a successful + // casProducerIndex are they + // safe to use. + mask = this.producerMask; + buffer = this.producerBuffer; + // a successful CAS ties the ordering, lv(pIndex) -> [mask/buffer] -> cas(pIndex) + + // we want 'limit' slots, but will settle for whatever is visible to 'producerLimit' + long batchIndex = + Math.min(producerLimit, pIndex + 2l * limit); // -> producerLimit >= batchIndex + + if (pIndex >= producerLimit) { + int result = offerSlowPath(mask, pIndex, producerLimit); + switch (result) { + case CONTINUE_TO_P_INDEX_CAS: + // offer slow path verifies only one slot ahead, we cannot rely on indication here + case RETRY: + continue; + case QUEUE_FULL: + return 0; + case QUEUE_RESIZE: + resize(mask, buffer, pIndex, null, s); + return 1; + } + } + + // claim limit slots at once + if (casProducerIndex(pIndex, batchIndex)) { + claimedSlots = (int) ((batchIndex - pIndex) / 2); + break; + } + } + + for (int i = 0; i < claimedSlots; i++) { + final long offset = modifiedCalcCircularRefElementOffset(pIndex + 2l * i, mask); + soRefElement(buffer, offset, s.get()); + } + return claimedSlots; + } + + @Override + public void fill(Supplier s, WaitStrategy wait, ExitCondition exit) { + MessagePassingQueueUtil.fill(this, s, wait, exit); + } + + @Override + public int drain(Consumer c) { + return drain(c, capacity()); + } + + @Override + public int drain(Consumer c, int limit) { + return MessagePassingQueueUtil.drain(this, c, limit); + } + + @Override + public void drain(Consumer c, WaitStrategy wait, ExitCondition exit) { + MessagePassingQueueUtil.drain(this, c, wait, exit); + } + + /** + * Get an iterator for this queue. This method is thread safe. + * + *

The iterator provides a best-effort snapshot of the elements in the queue. The returned + * iterator is not guaranteed to return elements in queue order, and races with the consumer + * thread may cause gaps in the sequence of returned elements. Like {link #relaxedPoll}, the + * iterator may not immediately return newly inserted elements. + * + * @return The iterator. + */ + @Override + public Iterator iterator() { + return new WeakIterator(consumerBuffer, lvConsumerIndex(), lvProducerIndex()); + } + + private static class WeakIterator implements Iterator { + private final long pIndex; + private long nextIndex; + private E nextElement; + private E[] currentBuffer; + private int mask; + + WeakIterator(E[] currentBuffer, long cIndex, long pIndex) { + this.pIndex = pIndex >> 1; + this.nextIndex = cIndex >> 1; + setBuffer(currentBuffer); + nextElement = getNext(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException("remove"); + } + + @Override + public boolean hasNext() { + return nextElement != null; + } + + @Override + public E next() { + final E e = nextElement; + if (e == null) { + throw new NoSuchElementException(); + } + nextElement = getNext(); + return e; + } + + private void setBuffer(E[] buffer) { + this.currentBuffer = buffer; + this.mask = length(buffer) - 2; + } + + private E getNext() { + while (nextIndex < pIndex) { + long index = nextIndex++; + E e = lvRefElement(currentBuffer, calcCircularRefElementOffset(index, mask)); + // skip removed/not yet visible elements + if (e == null) { + continue; + } + + // not null && not JUMP -> found next element + if (e != JUMP) { + return e; + } + + // need to jump to the next buffer + int nextBufferIndex = mask + 1; + Object nextBuffer = lvRefElement(currentBuffer, calcRefElementOffset(nextBufferIndex)); + + if (nextBuffer == BUFFER_CONSUMED || nextBuffer == null) { + // Consumer may have passed us, or the next buffer is not visible yet: drop out early + return null; + } + + setBuffer((E[]) nextBuffer); + // now with the new array retry the load, it can't be a JUMP, but we need to repeat same + // index + e = lvRefElement(currentBuffer, calcCircularRefElementOffset(index, mask)); + // skip removed/not yet visible elements + if (e == null) { + continue; + } else { + return e; + } + } + return null; + } + } + + private void resize(long oldMask, E[] oldBuffer, long pIndex, E e, Supplier s) { + assert (e != null && s == null) || (e == null || s != null); + int newBufferLength = getNextBufferSize(oldBuffer); + final E[] newBuffer; + try { + newBuffer = allocateRefArray(newBufferLength); + } catch (OutOfMemoryError oom) { + assert lvProducerIndex() == pIndex + 1; + soProducerIndex(pIndex); + throw oom; + } + + producerBuffer = newBuffer; + final int newMask = (newBufferLength - 2) << 1; + producerMask = newMask; + + final long offsetInOld = modifiedCalcCircularRefElementOffset(pIndex, oldMask); + final long offsetInNew = modifiedCalcCircularRefElementOffset(pIndex, newMask); + + soRefElement(newBuffer, offsetInNew, e == null ? s.get() : e); // element in new array + soRefElement(oldBuffer, nextArrayOffset(oldMask), newBuffer); // buffer linked + + // ASSERT code + final long cIndex = lvConsumerIndex(); + final long availableInQueue = availableInQueue(pIndex, cIndex); + RangeUtil.checkPositive(availableInQueue, "availableInQueue"); + + // Invalidate racing CASs + // We never set the limit beyond the bounds of a buffer + soProducerLimit(pIndex + Math.min(newMask, availableInQueue)); + + // make resize visible to the other producers + soProducerIndex(pIndex + 2); + + // INDEX visible before ELEMENT, consistent with consumer expectation + + // make resize visible to consumer + soRefElement(oldBuffer, offsetInOld, JUMP); + } + + /** @return next buffer size(inclusive of next array pointer) */ + protected abstract int getNextBufferSize(E[] buffer); + + /** @return current buffer capacity for elements (excluding next pointer and jump entry) * 2 */ + protected abstract long getCurrentBufferCapacity(long mask); +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/IndexedQueueSizeUtil.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/IndexedQueueSizeUtil.java new file mode 100644 index 000000000..40116bbe1 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/IndexedQueueSizeUtil.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +final class IndexedQueueSizeUtil { + public static int size(IndexedQueue iq) { + /* + * It is possible for a thread to be interrupted or reschedule between the read of the producer and + * consumer indices, therefore protection is required to ensure size is within valid range. In the + * event of concurrent polls/offers to this method the size is OVER estimated as we read consumer + * index BEFORE the producer index. + */ + long after = iq.lvConsumerIndex(); + long size; + while (true) { + final long before = after; + final long currentProducerIndex = iq.lvProducerIndex(); + after = iq.lvConsumerIndex(); + if (before == after) { + size = (currentProducerIndex - after); + break; + } + } + // Long overflow is impossible (), so size is always positive. Integer overflow is possible for + // the unbounded + // indexed queues. + if (size > Integer.MAX_VALUE) { + return Integer.MAX_VALUE; + } else { + return (int) size; + } + } + + public static boolean isEmpty(IndexedQueue iq) { + // Order matters! + // Loading consumer before producer allows for producer increments after consumer index is read. + // This ensures this method is conservative in it's estimate. Note that as this is an MPMC there + // is + // nothing we can do to make this an exact method. + return (iq.lvConsumerIndex() == iq.lvProducerIndex()); + } + + public interface IndexedQueue { + long lvConsumerIndex(); + + long lvProducerIndex(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedArrayQueueUtil.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedArrayQueueUtil.java new file mode 100644 index 000000000..37651f351 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedArrayQueueUtil.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.REF_ARRAY_BASE; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.REF_ELEMENT_SHIFT; + +/** This is used for method substitution in the LinkedArray classes code generation. */ +final class LinkedArrayQueueUtil { + static int length(Object[] buf) { + return buf.length; + } + + /** + * This method assumes index is actually (index << 1) because lower bit is used for resize. This + * is compensated for by reducing the element shift. The computation is constant folded, so + * there's no cost. + */ + static long modifiedCalcCircularRefElementOffset(long index, long mask) { + return REF_ARRAY_BASE + ((index & mask) << (REF_ELEMENT_SHIFT - 1)); + } + + static long nextArrayOffset(Object[] curr) { + return REF_ARRAY_BASE + ((long) (length(curr) - 1) << REF_ELEMENT_SHIFT); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedQueueNode.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedQueueNode.java new file mode 100644 index 000000000..72e78bb92 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedQueueNode.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import static io.rsocket.internal.jctools.queues.UnsafeAccess.UNSAFE; +import static io.rsocket.internal.jctools.queues.UnsafeAccess.fieldOffset; + +final class LinkedQueueNode { + private static final long NEXT_OFFSET = fieldOffset(LinkedQueueNode.class, "next"); + + private E value; + private volatile LinkedQueueNode next; + + LinkedQueueNode() { + this(null); + } + + LinkedQueueNode(E val) { + spValue(val); + } + + /** + * Gets the current value and nulls out the reference to it from this node. + * + * @return value + */ + public E getAndNullValue() { + E temp = lpValue(); + spValue(null); + return temp; + } + + public E lpValue() { + return value; + } + + public void spValue(E newValue) { + value = newValue; + } + + public void soNext(LinkedQueueNode n) { + UNSAFE.putOrderedObject(this, NEXT_OFFSET, n); + } + + public void spNext(LinkedQueueNode n) { + UNSAFE.putObject(this, NEXT_OFFSET, n); + } + + public LinkedQueueNode lvNext() { + return next; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueue.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueue.java new file mode 100644 index 000000000..7a0fa901f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueue.java @@ -0,0 +1,339 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import java.util.Queue; + +/** + * Message passing queues are intended for concurrent method passing. A subset of {@link Queue} + * methods are provided with the same semantics, while further functionality which accomodates the + * concurrent usecase is also on offer. + * + *

Message passing queues provide happens before semantics to messages passed through, namely + * that writes made by the producer before offering the message are visible to the consuming thread + * after the message has been polled out of the queue. + * + * @param the event/message type + */ +public interface MessagePassingQueue { + int UNBOUNDED_CAPACITY = -1; + + interface Supplier { + /** + * This method will return the next value to be written to the queue. As such the queue + * implementations are commited to insert the value once the call is made. + * + *

Users should be aware that underlying queue implementations may upfront claim parts of the + * queue for batch operations and this will effect the view on the queue from the supplier + * method. In particular size and any offer methods may take the view that the full batch has + * already happened. + * + *

WARNING: this method is assumed to never throw. Breaking this assumption can lead + * to a broken queue. + * + *

WARNING: this method is assumed to never return {@code null}. Breaking this + * assumption can lead to a broken queue. + * + * @return new element, NEVER {@code null} + */ + T get(); + } + + interface Consumer { + /** + * This method will process an element already removed from the queue. This method is expected + * to never throw an exception. + * + *

Users should be aware that underlying queue implementations may upfront claim parts of the + * queue for batch operations and this will effect the view on the queue from the accept method. + * In particular size and any poll/peek methods may take the view that the full batch has + * already happened. + * + *

WARNING: this method is assumed to never throw. Breaking this assumption can lead + * to a broken queue. + * + * @param e not {@code null} + */ + void accept(T e); + } + + interface WaitStrategy { + /** + * This method can implement static or dynamic backoff. Dynamic backoff will rely on the counter + * for estimating how long the caller has been idling. The expected usage is: + * + *

+ * + *

+     * 
+     * int ic = 0;
+     * while(true) {
+     *   if(!isGodotArrived()) {
+     *     ic = w.idle(ic);
+     *     continue;
+     *   }
+     *   ic = 0;
+     *   // party with Godot until he goes again
+     * }
+     * 
+     * 
+ * + * @param idleCounter idle calls counter, managed by the idle method until reset + * @return new counter value to be used on subsequent idle cycle + */ + int idle(int idleCounter); + } + + interface ExitCondition { + + /** + * This method should be implemented such that the flag read or determination cannot be hoisted + * out of a loop which notmally means a volatile load, but with JDK9 VarHandles may mean + * getOpaque. + * + * @return true as long as we should keep running + */ + boolean keepRunning(); + } + + /** + * Called from a producer thread subject to the restrictions appropriate to the implementation and + * according to the {@link Queue#offer(Object)} interface. + * + * @param e not {@code null}, will throw NPE if it is + * @return true if element was inserted into the queue, false iff full + */ + boolean offer(T e); + + /** + * Called from the consumer thread subject to the restrictions appropriate to the implementation + * and according to the {@link Queue#poll()} interface. + * + * @return a message from the queue if one is available, {@code null} iff empty + */ + T poll(); + + /** + * Called from the consumer thread subject to the restrictions appropriate to the implementation + * and according to the {@link Queue#peek()} interface. + * + * @return a message from the queue if one is available, {@code null} iff empty + */ + T peek(); + + /** + * This method's accuracy is subject to concurrent modifications happening as the size is + * estimated and as such is a best effort rather than absolute value. For some implementations + * this method may be O(n) rather than O(1). + * + * @return number of messages in the queue, between 0 and {@link Integer#MAX_VALUE} but less or + * equals to capacity (if bounded). + */ + int size(); + + /** + * Removes all items from the queue. Called from the consumer thread subject to the restrictions + * appropriate to the implementation and according to the {@link Queue#clear()} interface. + */ + void clear(); + + /** + * This method's accuracy is subject to concurrent modifications happening as the observation is + * carried out. + * + * @return true if empty, false otherwise + */ + boolean isEmpty(); + + /** + * @return the capacity of this queue or {@link MessagePassingQueue#UNBOUNDED_CAPACITY} if not + * bounded + */ + int capacity(); + + /** + * Called from a producer thread subject to the restrictions appropriate to the implementation. As + * opposed to {@link Queue#offer(Object)} this method may return false without the queue being + * full. + * + * @param e not {@code null}, will throw NPE if it is + * @return true if element was inserted into the queue, false if unable to offer + */ + boolean relaxedOffer(T e); + + /** + * Called from the consumer thread subject to the restrictions appropriate to the implementation. + * As opposed to {@link Queue#poll()} this method may return {@code null} without the queue being + * empty. + * + * @return a message from the queue if one is available, {@code null} if unable to poll + */ + T relaxedPoll(); + + /** + * Called from the consumer thread subject to the restrictions appropriate to the implementation. + * As opposed to {@link Queue#peek()} this method may return {@code null} without the queue being + * empty. + * + * @return a message from the queue if one is available, {@code null} if unable to peek + */ + T relaxedPeek(); + + /** + * Remove up to limit elements from the queue and hand to consume. This should be + * semantically similar to: + * + *

+ * + *

{@code
+   * M m;
+   * int i = 0;
+   * for(;i < limit && (m = relaxedPoll()) != null; i++){
+   *   c.accept(m);
+   * }
+   * return i;
+   * }
+ * + *

There's no strong commitment to the queue being empty at the end of a drain. Called from a + * consumer thread subject to the restrictions appropriate to the implementation. + * + *

WARNING: Explicit assumptions are made with regards to {@link Consumer#accept} make + * sure you have read and understood these before using this method. + * + * @return the number of polled elements + * @throws IllegalArgumentException c is {@code null} + * @throws IllegalArgumentException if limit is negative + */ + int drain(Consumer c, int limit); + + /** + * Stuff the queue with up to limit elements from the supplier. Semantically similar to: + * + *

+ * + *

{@code
+   * for(int i=0; i < limit && relaxedOffer(s.get()); i++);
+   * }
+ * + *

There's no strong commitment to the queue being full at the end of a fill. Called from a + * producer thread subject to the restrictions appropriate to the implementation. + * + *

WARNING: Explicit assumptions are made with regards to {@link Supplier#get} make sure + * you have read and understood these before using this method. + * + * @return the number of offered elements + * @throws IllegalArgumentException s is {@code null} + * @throws IllegalArgumentException if limit is negative + */ + int fill(Supplier s, int limit); + + /** + * Remove all available item from the queue and hand to consume. This should be semantically + * similar to: + * + *

+   * M m;
+   * while((m = relaxedPoll()) != null){
+   * c.accept(m);
+   * }
+   * 
+ * + * There's no strong commitment to the queue being empty at the end of a drain. Called from a + * consumer thread subject to the restrictions appropriate to the implementation. + * + *

WARNING: Explicit assumptions are made with regards to {@link Consumer#accept} make + * sure you have read and understood these before using this method. + * + * @return the number of polled elements + * @throws IllegalArgumentException c is {@code null} + */ + int drain(Consumer c); + + /** + * Stuff the queue with elements from the supplier. Semantically similar to: + * + *

+   * while(relaxedOffer(s.get());
+   * 
+ * + * There's no strong commitment to the queue being full at the end of a fill. Called from a + * producer thread subject to the restrictions appropriate to the implementation. + * + *

Unbounded queues will fill up the queue with a fixed amount rather than fill up to oblivion. + * + *

WARNING: Explicit assumptions are made with regards to {@link Supplier#get} make sure + * you have read and understood these before using this method. + * + * @return the number of offered elements + * @throws IllegalArgumentException s is {@code null} + */ + int fill(Supplier s); + + /** + * Remove elements from the queue and hand to consume forever. Semantically similar to: + * + *

+ * + *

+   *  int idleCounter = 0;
+   *  while (exit.keepRunning()) {
+   *      E e = relaxedPoll();
+   *      if(e==null){
+   *          idleCounter = wait.idle(idleCounter);
+   *          continue;
+   *      }
+   *      idleCounter = 0;
+   *      c.accept(e);
+   *  }
+   * 
+ * + *

Called from a consumer thread subject to the restrictions appropriate to the implementation. + * + *

WARNING: Explicit assumptions are made with regards to {@link Consumer#accept} make + * sure you have read and understood these before using this method. + * + * @throws IllegalArgumentException c OR wait OR exit are {@code null} + */ + void drain(Consumer c, WaitStrategy wait, ExitCondition exit); + + /** + * Stuff the queue with elements from the supplier forever. Semantically similar to: + * + *

+ * + *

+   * 
+   *  int idleCounter = 0;
+   *  while (exit.keepRunning()) {
+   *      E e = s.get();
+   *      while (!relaxedOffer(e)) {
+   *          idleCounter = wait.idle(idleCounter);
+   *          continue;
+   *      }
+   *      idleCounter = 0;
+   *  }
+   * 
+   * 
+ * + *

Called from a producer thread subject to the restrictions appropriate to the implementation. + * The main difference being that implementors MUST assure room in the queue is available BEFORE + * calling {@link Supplier#get}. + * + *

WARNING: Explicit assumptions are made with regards to {@link Supplier#get} make sure + * you have read and understood these before using this method. + * + * @throws IllegalArgumentException s OR wait OR exit are {@code null} + */ + void fill(Supplier s, WaitStrategy wait, ExitCondition exit); +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueueUtil.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueueUtil.java new file mode 100644 index 000000000..cb03364d8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueueUtil.java @@ -0,0 +1,100 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.internal.jctools.queues; + +import io.rsocket.internal.jctools.queues.MessagePassingQueue.Consumer; +import io.rsocket.internal.jctools.queues.MessagePassingQueue.ExitCondition; +import io.rsocket.internal.jctools.queues.MessagePassingQueue.Supplier; +import io.rsocket.internal.jctools.queues.MessagePassingQueue.WaitStrategy; + +final class MessagePassingQueueUtil { + public static int drain(MessagePassingQueue queue, Consumer c, int limit) { + if (null == c) throw new IllegalArgumentException("c is null"); + if (limit < 0) throw new IllegalArgumentException("limit is negative: " + limit); + if (limit == 0) return 0; + E e; + int i = 0; + for (; i < limit && (e = queue.relaxedPoll()) != null; i++) { + c.accept(e); + } + return i; + } + + public static int drain(MessagePassingQueue queue, Consumer c) { + if (null == c) throw new IllegalArgumentException("c is null"); + E e; + int i = 0; + while ((e = queue.relaxedPoll()) != null) { + i++; + c.accept(e); + } + return i; + } + + public static void drain( + MessagePassingQueue queue, Consumer c, WaitStrategy wait, ExitCondition exit) { + if (null == c) throw new IllegalArgumentException("c is null"); + if (null == wait) throw new IllegalArgumentException("wait is null"); + if (null == exit) throw new IllegalArgumentException("exit condition is null"); + + int idleCounter = 0; + while (exit.keepRunning()) { + final E e = queue.relaxedPoll(); + if (e == null) { + idleCounter = wait.idle(idleCounter); + continue; + } + idleCounter = 0; + c.accept(e); + } + } + + public static void fill( + MessagePassingQueue q, Supplier s, WaitStrategy wait, ExitCondition exit) { + if (null == wait) throw new IllegalArgumentException("waiter is null"); + if (null == exit) throw new IllegalArgumentException("exit condition is null"); + + int idleCounter = 0; + while (exit.keepRunning()) { + if (q.fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH) == 0) { + idleCounter = wait.idle(idleCounter); + continue; + } + idleCounter = 0; + } + } + + public static int fillBounded(MessagePassingQueue q, Supplier s) { + return fillInBatchesToLimit(q, s, PortableJvmInfo.RECOMENDED_OFFER_BATCH, q.capacity()); + } + + public static int fillInBatchesToLimit( + MessagePassingQueue q, Supplier s, int batch, int limit) { + long result = + 0; // result is a long because we want to have a safepoint check at regular intervals + do { + final int filled = q.fill(s, batch); + if (filled == 0) { + return (int) result; + } + result += filled; + } while (result <= limit); + return (int) result; + } + + public static int fillUnbounded(MessagePassingQueue q, Supplier s) { + return fillInBatchesToLimit(q, s, PortableJvmInfo.RECOMENDED_OFFER_BATCH, 4096); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MpscUnboundedArrayQueue.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MpscUnboundedArrayQueue.java new file mode 100644 index 000000000..179070be4 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MpscUnboundedArrayQueue.java @@ -0,0 +1,76 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import static io.rsocket.internal.jctools.queues.LinkedArrayQueueUtil.length; +import static io.rsocket.internal.jctools.queues.MessagePassingQueueUtil.fillUnbounded; + +/** + * An MPSC array queue which starts at initialCapacity and grows indefinitely in linked + * chunks of the initial size. The queue grows only when the current chunk is full and elements are + * not copied on resize, instead a link to the new chunk is stored in the old chunk for the consumer + * to follow. + */ +public class MpscUnboundedArrayQueue extends BaseMpscLinkedArrayQueue { + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b + + public MpscUnboundedArrayQueue(int chunkSize) { + super(chunkSize); + } + + @Override + protected long availableInQueue(long pIndex, long cIndex) { + return Integer.MAX_VALUE; + } + + @Override + public int capacity() { + return MessagePassingQueue.UNBOUNDED_CAPACITY; + } + + @Override + public int drain(Consumer c) { + return drain(c, 4096); + } + + @Override + public int fill(Supplier s) { + return fillUnbounded(this, s); + } + + @Override + protected int getNextBufferSize(E[] buffer) { + return length(buffer); + } + + @Override + protected long getCurrentBufferCapacity(long mask) { + return mask; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/PortableJvmInfo.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/PortableJvmInfo.java new file mode 100644 index 000000000..f037857e8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/PortableJvmInfo.java @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +/** JVM Information that is standard and available on all JVMs (i.e. does not use unsafe) */ +interface PortableJvmInfo { + int CACHE_LINE_SIZE = Integer.getInteger("jctools.cacheLineSize", 64); + int CPUs = Runtime.getRuntime().availableProcessors(); + int RECOMENDED_OFFER_BATCH = CPUs * 4; + int RECOMENDED_POLL_BATCH = CPUs * 4; +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/Pow2.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/Pow2.java new file mode 100644 index 000000000..282a22f02 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/Pow2.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +/** Power of 2 utility functions. */ +final class Pow2 { + public static final int MAX_POW2 = 1 << 30; + + /** + * @param value from which next positive power of two will be found. + * @return the next positive power of 2, this value if it is a power of 2. Negative values are + * mapped to 1. + * @throws IllegalArgumentException is value is more than MAX_POW2 or less than 0 + */ + public static int roundToPowerOfTwo(final int value) { + if (value > MAX_POW2) { + throw new IllegalArgumentException( + "There is no larger power of 2 int for value:" + value + " since it exceeds 2^31."); + } + if (value < 0) { + throw new IllegalArgumentException("Given value:" + value + ". Expecting value >= 0."); + } + final int nextPow2 = 1 << (32 - Integer.numberOfLeadingZeros(value - 1)); + return nextPow2; + } + + /** + * @param value to be tested to see if it is a power of two. + * @return true if the value is a power of 2 otherwise false. + */ + public static boolean isPowerOfTwo(final int value) { + return (value & (value - 1)) == 0; + } + + /** + * Align a value to the next multiple up of alignment. If the value equals an alignment multiple + * then it is returned unchanged. + * + * @param value to be aligned up. + * @param alignment to be used, must be a power of 2. + * @return the value aligned to the next boundary. + */ + public static long align(final long value, final int alignment) { + if (!isPowerOfTwo(alignment)) { + throw new IllegalArgumentException("alignment must be a power of 2:" + alignment); + } + return (value + (alignment - 1)) & ~(alignment - 1); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/QueueProgressIndicators.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/QueueProgressIndicators.java new file mode 100644 index 000000000..6418cc947 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/QueueProgressIndicators.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +/** + * This interface is provided for monitoring purposes only and is only available on queues where it + * is easy to provide it. The producer/consumer progress indicators usually correspond with the + * number of elements offered/polled, but they are not guaranteed to maintain that semantic. + * + * @author nitsanw + */ +public interface QueueProgressIndicators { + + /** + * This method has no concurrent visibility semantics. The value returned may be negative. Under + * normal circumstances 2 consecutive calls to this method can offer an idea of progress made by + * producer threads by subtracting the 2 results though in extreme cases (if producers have + * progressed by more than 2^64) this may also fail.
+ * This value will normally indicate number of elements passed into the queue, but may under some + * circumstances be a derivative of that figure. This method should not be used to derive size or + * emptiness. + * + * @return the current value of the producer progress index + */ + long currentProducerIndex(); + + /** + * This method has no concurrent visibility semantics. The value returned may be negative. Under + * normal circumstances 2 consecutive calls to this method can offer an idea of progress made by + * consumer threads by subtracting the 2 results though in extreme cases (if consumers have + * progressed by more than 2^64) this may also fail.
+ * This value will normally indicate number of elements taken out of the queue, but may under some + * circumstances be a derivative of that figure. This method should not be used to derive size or + * emptiness. + * + * @return the current value of the consumer progress index + */ + long currentConsumerIndex(); +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/RangeUtil.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/RangeUtil.java new file mode 100644 index 000000000..3adcb2f3c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/RangeUtil.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +final class RangeUtil { + public static long checkPositive(long n, String name) { + if (n <= 0) { + throw new IllegalArgumentException(name + ": " + n + " (expected: > 0)"); + } + + return n; + } + + public static int checkPositiveOrZero(int n, String name) { + if (n < 0) { + throw new IllegalArgumentException(name + ": " + n + " (expected: >= 0)"); + } + + return n; + } + + public static int checkLessThan(int n, int expected, String name) { + if (n >= expected) { + throw new IllegalArgumentException(name + ": " + n + " (expected: < " + expected + ')'); + } + + return n; + } + + public static int checkLessThanOrEqual(int n, long expected, String name) { + if (n > expected) { + throw new IllegalArgumentException(name + ": " + n + " (expected: <= " + expected + ')'); + } + + return n; + } + + public static int checkGreaterThanOrEqual(int n, int expected, String name) { + if (n < expected) { + throw new IllegalArgumentException(name + ": " + n + " (expected: >= " + expected + ')'); + } + + return n; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeAccess.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeAccess.java new file mode 100644 index 000000000..c99aeb689 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeAccess.java @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.util.concurrent.atomic.AtomicReferenceArray; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import sun.misc.Unsafe; + +/** + * Why should we resort to using Unsafe?
+ * + *

    + *
  1. To construct class fields which allow volatile/ordered/plain access: This requirement is + * covered by {@link AtomicReferenceFieldUpdater} and similar but their performance is + * arguably worse than the DIY approach (depending on JVM version) while Unsafe + * intrinsification is a far lesser challenge for JIT compilers. + *
  2. To construct flavors of {@link AtomicReferenceArray}. + *
  3. Other use cases exist but are not present in this library yet. + *
+ * + * @author nitsanw + */ +class UnsafeAccess { + public static final boolean SUPPORTS_GET_AND_SET_REF; + public static final boolean SUPPORTS_GET_AND_ADD_LONG; + public static final Unsafe UNSAFE; + + static { + UNSAFE = getUnsafe(); + SUPPORTS_GET_AND_SET_REF = hasGetAndSetSupport(); + SUPPORTS_GET_AND_ADD_LONG = hasGetAndAddLongSupport(); + } + + private static Unsafe getUnsafe() { + Unsafe instance; + try { + final Field field = Unsafe.class.getDeclaredField("theUnsafe"); + field.setAccessible(true); + instance = (Unsafe) field.get(null); + } catch (Exception ignored) { + // Some platforms, notably Android, might not have a sun.misc.Unsafe implementation with a + // private + // `theUnsafe` static instance. In this case we can try to call the default constructor, which + // is sufficient + // for Android usage. + try { + Constructor c = Unsafe.class.getDeclaredConstructor(); + c.setAccessible(true); + instance = c.newInstance(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + return instance; + } + + private static boolean hasGetAndSetSupport() { + try { + Unsafe.class.getMethod("getAndSetObject", Object.class, Long.TYPE, Object.class); + return true; + } catch (Exception ignored) { + } + return false; + } + + private static boolean hasGetAndAddLongSupport() { + try { + Unsafe.class.getMethod("getAndAddLong", Object.class, Long.TYPE, Long.TYPE); + return true; + } catch (Exception ignored) { + } + return false; + } + + public static long fieldOffset(Class clz, String fieldName) throws RuntimeException { + try { + return UNSAFE.objectFieldOffset(clz.getDeclaredField(fieldName)); + } catch (NoSuchFieldException e) { + throw new RuntimeException(e); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeRefArrayAccess.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeRefArrayAccess.java new file mode 100644 index 000000000..c734a9914 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeRefArrayAccess.java @@ -0,0 +1,104 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import static io.rsocket.internal.jctools.queues.UnsafeAccess.UNSAFE; + +final class UnsafeRefArrayAccess { + public static final long REF_ARRAY_BASE; + public static final int REF_ELEMENT_SHIFT; + + static { + final int scale = UNSAFE.arrayIndexScale(Object[].class); + if (4 == scale) { + REF_ELEMENT_SHIFT = 2; + } else if (8 == scale) { + REF_ELEMENT_SHIFT = 3; + } else { + throw new IllegalStateException("Unknown pointer size: " + scale); + } + REF_ARRAY_BASE = UNSAFE.arrayBaseOffset(Object[].class); + } + + /** + * A plain store (no ordering/fences) of an element to a given offset + * + * @param buffer this.buffer + * @param offset computed via {@link UnsafeRefArrayAccess#calcRefElementOffset(long)} + * @param e an orderly kitty + */ + public static void spRefElement(E[] buffer, long offset, E e) { + UNSAFE.putObject(buffer, offset, e); + } + + /** + * An ordered store of an element to a given offset + * + * @param buffer this.buffer + * @param offset computed via {@link UnsafeRefArrayAccess#calcCircularRefElementOffset} + * @param e an orderly kitty + */ + public static void soRefElement(E[] buffer, long offset, E e) { + UNSAFE.putOrderedObject(buffer, offset, e); + } + + /** + * A plain load (no ordering/fences) of an element from a given offset. + * + * @param buffer this.buffer + * @param offset computed via {@link UnsafeRefArrayAccess#calcRefElementOffset(long)} + * @return the element at the offset + */ + @SuppressWarnings("unchecked") + public static E lpRefElement(E[] buffer, long offset) { + return (E) UNSAFE.getObject(buffer, offset); + } + + /** + * A volatile load of an element from a given offset. + * + * @param buffer this.buffer + * @param offset computed via {@link UnsafeRefArrayAccess#calcRefElementOffset(long)} + * @return the element at the offset + */ + @SuppressWarnings("unchecked") + public static E lvRefElement(E[] buffer, long offset) { + return (E) UNSAFE.getObjectVolatile(buffer, offset); + } + + /** + * @param index desirable element index + * @return the offset in bytes within the array for a given index + */ + public static long calcRefElementOffset(long index) { + return REF_ARRAY_BASE + (index << REF_ELEMENT_SHIFT); + } + + /** + * Note: circular arrays are assumed a power of 2 in length and the `mask` is (length - 1). + * + * @param index desirable element index + * @param mask (length - 1) + * @return the offset in bytes within the circular array for a given index + */ + public static long calcCircularRefElementOffset(long index, long mask) { + return REF_ARRAY_BASE + ((index & mask) << REF_ELEMENT_SHIFT); + } + + /** This makes for an easier time generating the atomic queues, and removes some warnings. */ + @SuppressWarnings("unchecked") + public static E[] allocateRefArray(int capacity) { + return (E[]) new Object[capacity]; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/package-info.java b/rsocket-core/src/main/java/io/rsocket/internal/package-info.java index c1ed71ce8..07ddfab41 100644 --- a/rsocket-core/src/main/java/io/rsocket/internal/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -18,5 +18,7 @@ * Internal package and must not be used outside this project. There are no guarantees for * API compatibility. */ -@javax.annotation.ParametersAreNonnullByDefault +@NonNullApi package io.rsocket.internal; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveFramesAcceptor.java b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveFramesAcceptor.java new file mode 100644 index 000000000..8fb918dc6 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveFramesAcceptor.java @@ -0,0 +1,9 @@ +package io.rsocket.keepalive; + +import io.netty.buffer.ByteBuf; +import reactor.core.Disposable; + +public interface KeepAliveFramesAcceptor extends Disposable { + + void receive(ByteBuf keepAliveFrame); +} diff --git a/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveHandler.java b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveHandler.java new file mode 100644 index 000000000..4fd7a772d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveHandler.java @@ -0,0 +1,60 @@ +package io.rsocket.keepalive; + +import io.netty.buffer.ByteBuf; +import io.rsocket.keepalive.KeepAliveSupport.KeepAlive; +import io.rsocket.resume.RSocketSession; +import io.rsocket.resume.ResumableDuplexConnection; +import io.rsocket.resume.ResumeStateHolder; +import java.util.function.Consumer; + +public interface KeepAliveHandler { + + KeepAliveFramesAcceptor start( + KeepAliveSupport keepAliveSupport, + Consumer onFrameSent, + Consumer onTimeout); + + class DefaultKeepAliveHandler implements KeepAliveHandler { + @Override + public KeepAliveFramesAcceptor start( + KeepAliveSupport keepAliveSupport, + Consumer onSendKeepAliveFrame, + Consumer onTimeout) { + return keepAliveSupport + .onSendKeepAliveFrame(onSendKeepAliveFrame) + .onTimeout(onTimeout) + .start(); + } + } + + class ResumableKeepAliveHandler implements KeepAliveHandler { + + private final ResumableDuplexConnection resumableDuplexConnection; + private final RSocketSession rSocketSession; + private final ResumeStateHolder resumeStateHolder; + + public ResumableKeepAliveHandler( + ResumableDuplexConnection resumableDuplexConnection, + RSocketSession rSocketSession, + ResumeStateHolder resumeStateHolder) { + this.resumableDuplexConnection = resumableDuplexConnection; + this.rSocketSession = rSocketSession; + this.resumeStateHolder = resumeStateHolder; + } + + @Override + public KeepAliveFramesAcceptor start( + KeepAliveSupport keepAliveSupport, + Consumer onSendKeepAliveFrame, + Consumer onTimeout) { + + rSocketSession.setKeepAliveSupport(keepAliveSupport); + + return keepAliveSupport + .resumeState(resumeStateHolder) + .onSendKeepAliveFrame(onSendKeepAliveFrame) + .onTimeout(keepAlive -> resumableDuplexConnection.disconnect()) + .start(); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java new file mode 100644 index 000000000..4fd18d041 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java @@ -0,0 +1,201 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.keepalive; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.resume.ResumeStateHolder; +import java.time.Duration; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.function.Consumer; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public abstract class KeepAliveSupport implements KeepAliveFramesAcceptor { + + final ByteBufAllocator allocator; + final Scheduler scheduler; + final Duration keepAliveInterval; + final Duration keepAliveTimeout; + final long keepAliveTimeoutMillis; + + volatile int state; + static final AtomicIntegerFieldUpdater STATE = + AtomicIntegerFieldUpdater.newUpdater(KeepAliveSupport.class, "state"); + + static final int STOPPED_STATE = 0; + static final int STARTING_STATE = 1; + static final int STARTED_STATE = 2; + static final int DISPOSED_STATE = -1; + + volatile Consumer onTimeout; + volatile Consumer onFrameSent; + + Disposable ticksDisposable; + + volatile ResumeStateHolder resumeStateHolder; + volatile long lastReceivedMillis; + + private KeepAliveSupport( + ByteBufAllocator allocator, int keepAliveInterval, int keepAliveTimeout) { + this.allocator = allocator; + this.scheduler = Schedulers.parallel(); + this.keepAliveInterval = Duration.ofMillis(keepAliveInterval); + this.keepAliveTimeout = Duration.ofMillis(keepAliveTimeout); + this.keepAliveTimeoutMillis = keepAliveTimeout; + } + + public KeepAliveSupport start() { + if (this.state == STOPPED_STATE && STATE.compareAndSet(this, STOPPED_STATE, STARTING_STATE)) { + this.lastReceivedMillis = scheduler.now(TimeUnit.MILLISECONDS); + + final Disposable disposable = + Flux.interval(keepAliveInterval, scheduler).subscribe(v -> onIntervalTick()); + this.ticksDisposable = disposable; + + if (this.state != STARTING_STATE + || !STATE.compareAndSet(this, STARTING_STATE, STARTED_STATE)) { + disposable.dispose(); + } + } + return this; + } + + public void stop() { + terminate(STOPPED_STATE); + } + + @Override + public void receive(ByteBuf keepAliveFrame) { + this.lastReceivedMillis = scheduler.now(TimeUnit.MILLISECONDS); + if (resumeStateHolder != null) { + final long remoteLastReceivedPos = KeepAliveFrameCodec.lastPosition(keepAliveFrame); + resumeStateHolder.onImpliedPosition(remoteLastReceivedPos); + } + if (KeepAliveFrameCodec.respondFlag(keepAliveFrame)) { + long localLastReceivedPos = localLastReceivedPosition(); + send( + KeepAliveFrameCodec.encode( + allocator, + false, + localLastReceivedPos, + KeepAliveFrameCodec.data(keepAliveFrame).retain())); + } + } + + public KeepAliveSupport resumeState(ResumeStateHolder resumeStateHolder) { + this.resumeStateHolder = resumeStateHolder; + return this; + } + + public KeepAliveSupport onSendKeepAliveFrame(Consumer onFrameSent) { + this.onFrameSent = onFrameSent; + return this; + } + + public KeepAliveSupport onTimeout(Consumer onTimeout) { + this.onTimeout = onTimeout; + return this; + } + + @Override + public void dispose() { + terminate(DISPOSED_STATE); + } + + @Override + public boolean isDisposed() { + return ticksDisposable.isDisposed(); + } + + abstract void onIntervalTick(); + + void send(ByteBuf frame) { + if (onFrameSent != null) { + onFrameSent.accept(frame); + } + } + + void tryTimeout() { + long now = scheduler.now(TimeUnit.MILLISECONDS); + if (now - lastReceivedMillis >= keepAliveTimeoutMillis) { + if (onTimeout != null) { + onTimeout.accept(new KeepAlive(keepAliveInterval, keepAliveTimeout)); + } + stop(); + } + } + + void terminate(int terminationState) { + for (; ; ) { + final int state = this.state; + + if (state == STOPPED_STATE || state == DISPOSED_STATE) { + return; + } + + final Disposable disposable = this.ticksDisposable; + if (STATE.compareAndSet(this, state, terminationState)) { + disposable.dispose(); + return; + } + } + } + + long localLastReceivedPosition() { + return resumeStateHolder != null ? resumeStateHolder.impliedPosition() : 0; + } + + public static final class ClientKeepAliveSupport extends KeepAliveSupport { + + public ClientKeepAliveSupport( + ByteBufAllocator allocator, int keepAliveInterval, int keepAliveTimeout) { + super(allocator, keepAliveInterval, keepAliveTimeout); + } + + @Override + void onIntervalTick() { + tryTimeout(); + send( + KeepAliveFrameCodec.encode( + allocator, true, localLastReceivedPosition(), Unpooled.EMPTY_BUFFER)); + } + } + + public static final class KeepAlive { + private final Duration tickPeriod; + private final Duration timeoutMillis; + + public KeepAlive(Duration tickPeriod, Duration timeoutMillis) { + this.tickPeriod = tickPeriod; + this.timeoutMillis = timeoutMillis; + } + + public Duration getTickPeriod() { + return tickPeriod; + } + + public Duration getTimeout() { + return timeoutMillis; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/keepalive/package-info.java b/rsocket-core/src/main/java/io/rsocket/keepalive/package-info.java new file mode 100644 index 000000000..d94a93cad --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/keepalive/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Support classes for sending and keeping track of KEEPALIVE frames from the remote. */ +@NonNullApi +package io.rsocket.keepalive; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/lease/Lease.java b/rsocket-core/src/main/java/io/rsocket/lease/Lease.java index 416bc1998..9e76d176d 100644 --- a/rsocket-core/src/main/java/io/rsocket/lease/Lease.java +++ b/rsocket-core/src/main/java/io/rsocket/lease/Lease.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,32 +16,74 @@ package io.rsocket.lease; -import java.nio.ByteBuffer; -import javax.annotation.Nullable; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import java.time.Duration; +import reactor.util.annotation.Nullable; /** A contract for RSocket lease, which is sent by a request acceptor and is time bound. */ -public interface Lease { +public final class Lease { + + public static Lease create( + Duration timeToLive, int numberOfRequests, @Nullable ByteBuf metadata) { + return new Lease(timeToLive, numberOfRequests, metadata); + } + + public static Lease create(Duration timeToLive, int numberOfRequests) { + return create(timeToLive, numberOfRequests, Unpooled.EMPTY_BUFFER); + } + + public static Lease unbounded() { + return unbounded(null); + } + + public static Lease unbounded(@Nullable ByteBuf metadata) { + return create(Duration.ofMillis(Integer.MAX_VALUE), Integer.MAX_VALUE, metadata); + } + + public static Lease empty() { + return create(Duration.ZERO, 0); + } + + final int timeToLiveMillis; + final int numberOfRequests; + final ByteBuf metadata; + final long expirationTime; + + Lease(Duration timeToLive, int numberOfRequests, @Nullable ByteBuf metadata) { + this.numberOfRequests = numberOfRequests; + this.timeToLiveMillis = (int) Math.min(timeToLive.toMillis(), Integer.MAX_VALUE); + this.metadata = metadata == null ? Unpooled.EMPTY_BUFFER : metadata; + this.expirationTime = + timeToLive.isZero() ? 0 : System.currentTimeMillis() + timeToLive.toMillis(); + } /** * Number of requests allowed by this lease. * * @return The number of requests allowed by this lease. */ - int getAllowedRequests(); + public int numberOfRequests() { + return numberOfRequests; + } /** - * Number of seconds that this lease is valid from the time it is received. + * Time to live for the given lease * - * @return Number of seconds that this lease is valid from the time it is received. + * @return relative duration in milliseconds */ - int getTtl(); + public int timeToLiveInMillis() { + return this.timeToLiveMillis; + } /** * Absolute time since epoch at which this lease will expire. * * @return Absolute time since epoch at which this lease will expire. */ - long expiry(); + public long expirationTime() { + return expirationTime; + } /** * Metadata for the lease. @@ -49,24 +91,19 @@ public interface Lease { * @return Metadata for the lease. */ @Nullable - ByteBuffer getMetadata(); - - /** - * Checks if the lease is expired now. - * - * @return {@code true} if the lease has expired. - */ - default boolean isExpired() { - return isExpired(System.currentTimeMillis()); + public ByteBuf metadata() { + return metadata; } - /** - * Checks if the lease is expired for the passed {@code now}. - * - * @param now current time in millis. - * @return {@code true} if the lease has expired. - */ - default boolean isExpired(long now) { - return now > expiry(); + @Override + public String toString() { + return "Lease{" + + "timeToLiveMillis=" + + timeToLiveMillis + + ", numberOfRequests=" + + numberOfRequests + + ", expirationTime=" + + expirationTime + + '}'; } } diff --git a/rsocket-core/src/main/java/io/rsocket/lease/LeaseImpl.java b/rsocket-core/src/main/java/io/rsocket/lease/LeaseImpl.java deleted file mode 100644 index 0f99092fb..000000000 --- a/rsocket-core/src/main/java/io/rsocket/lease/LeaseImpl.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.lease; - -import io.rsocket.Frame; -import java.nio.ByteBuffer; -import javax.annotation.Nullable; - -public final class LeaseImpl implements Lease { - - private final int allowedRequests; - private final int ttl; - private final long expiry; - private final @Nullable ByteBuffer metadata; - - public LeaseImpl(int allowedRequests, int ttl) { - this(allowedRequests, ttl, null); - } - - public LeaseImpl(int allowedRequests, int ttl, ByteBuffer metadata) { - this.allowedRequests = allowedRequests; - this.ttl = ttl; - expiry = System.currentTimeMillis() + ttl; - this.metadata = metadata; - } - - public LeaseImpl(Frame leaseFrame) { - this( - Frame.Lease.numberOfRequests(leaseFrame), - Frame.Lease.ttl(leaseFrame), - leaseFrame.getMetadata()); - } - - @Override - public int getAllowedRequests() { - return allowedRequests; - } - - @Override - public int getTtl() { - return ttl; - } - - @Override - public long expiry() { - return expiry; - } - - @Override - public ByteBuffer getMetadata() { - return metadata; - } - - @Override - public String toString() { - return "LeaseImpl{" - + "allowedRequests=" - + allowedRequests - + ", ttl=" - + ttl - + ", expiry=" - + expiry - + '}'; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/lease/LeaseSender.java b/rsocket-core/src/main/java/io/rsocket/lease/LeaseSender.java new file mode 100644 index 000000000..48bd38494 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/lease/LeaseSender.java @@ -0,0 +1,8 @@ +package io.rsocket.lease; + +import reactor.core.publisher.Flux; + +public interface LeaseSender { + + Flux send(); +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SecureWebsocketClientServerTest.java b/rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java similarity index 53% rename from rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SecureWebsocketClientServerTest.java rename to rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java index ff787dc90..84af91b1b 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SecureWebsocketClientServerTest.java +++ b/rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,14 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.transport.netty; +package io.rsocket.lease; -import io.rsocket.test.BaseClientServerTest; +import io.rsocket.exceptions.RejectedException; + +public class MissingLeaseException extends RejectedException { + private static final long serialVersionUID = -6169748673403858959L; + + public MissingLeaseException(String message) { + super(message); + } -public class SecureWebsocketClientServerTest - extends BaseClientServerTest { @Override - protected SecureWebsocketClientSetupRule createClientServer() { - return new SecureWebsocketClientSetupRule(); + public synchronized Throwable fillInStackTrace() { + return this; } } diff --git a/rsocket-core/src/main/java/io/rsocket/lease/TrackingLeaseSender.java b/rsocket-core/src/main/java/io/rsocket/lease/TrackingLeaseSender.java new file mode 100644 index 000000000..3e6f68321 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/lease/TrackingLeaseSender.java @@ -0,0 +1,5 @@ +package io.rsocket.lease; + +import io.rsocket.plugins.RequestInterceptor; + +public interface TrackingLeaseSender extends LeaseSender, RequestInterceptor {} diff --git a/rsocket-core/src/main/java/io/rsocket/lease/package-info.java b/rsocket-core/src/main/java/io/rsocket/lease/package-info.java index c0109b3e5..342ab27f7 100644 --- a/rsocket-core/src/main/java/io/rsocket/lease/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/lease/package-info.java @@ -1,18 +1,27 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -@javax.annotation.ParametersAreNonnullByDefault +/** + * Contains support classes for the Lease feature of the RSocket protocol. + * + * @see Resuming + * Operation + */ +@NonNullApi package io.rsocket.lease; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/BaseWeightedStats.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/BaseWeightedStats.java new file mode 100644 index 000000000..fdbbeb25d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/BaseWeightedStats.java @@ -0,0 +1,235 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.util.Clock; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; + +/** + * Implementation of {@link WeightedStats} that manages tracking state and exposes the required + * stats. + * + *

A sub-class or a different class (delegation) needs to call {@link #startStream()}, {@link + * #stopStream()}, {@link #startRequest()}, and {@link #stopRequest(long)} to drive state tracking. + * + * @since 1.1 + * @see WeightedStatsRequestInterceptor + */ +public class BaseWeightedStats implements WeightedStats { + + private static final double DEFAULT_LOWER_QUANTILE = 0.5; + private static final double DEFAULT_HIGHER_QUANTILE = 0.8; + private static final int INACTIVITY_FACTOR = 500; + private static final long DEFAULT_INITIAL_INTER_ARRIVAL_TIME = + Clock.unit().convert(1L, TimeUnit.SECONDS); + + private static final double STARTUP_PENALTY = Long.MAX_VALUE >> 12; + + private final Quantile lowerQuantile; + private final Quantile higherQuantile; + private final Ewma availabilityPercentage; + private final Median median; + private final Ewma interArrivalTime; + + private final long tau; + private final long inactivityFactor; + + private long errorStamp; // last we got an error + private long stamp; // last timestamp we sent a request + private long stamp0; // last timestamp we sent a request or receive a response + private long duration; // instantaneous cumulative duration + + private volatile int pendingRequests; // instantaneous rate + private static final AtomicIntegerFieldUpdater PENDING_REQUESTS = + AtomicIntegerFieldUpdater.newUpdater(BaseWeightedStats.class, "pendingRequests"); + private volatile int pendingStreams; // number of active streams + private static final AtomicIntegerFieldUpdater PENDING_STREAMS = + AtomicIntegerFieldUpdater.newUpdater(BaseWeightedStats.class, "pendingStreams"); + + protected BaseWeightedStats() { + this( + new FrugalQuantile(DEFAULT_LOWER_QUANTILE), + new FrugalQuantile(DEFAULT_HIGHER_QUANTILE), + INACTIVITY_FACTOR); + } + + private BaseWeightedStats( + Quantile lowerQuantile, Quantile higherQuantile, long inactivityFactor) { + this.lowerQuantile = lowerQuantile; + this.higherQuantile = higherQuantile; + this.inactivityFactor = inactivityFactor; + + long now = Clock.now(); + this.stamp = now; + this.errorStamp = now; + this.stamp0 = now; + this.duration = 0L; + this.pendingRequests = 0; + this.median = new Median(); + this.interArrivalTime = new Ewma(1, TimeUnit.MINUTES, DEFAULT_INITIAL_INTER_ARRIVAL_TIME); + this.availabilityPercentage = new Ewma(5, TimeUnit.SECONDS, 1.0); + this.tau = Clock.unit().convert((long) (5 / Math.log(2)), TimeUnit.SECONDS); + } + + @Override + public double lowerQuantileLatency() { + return lowerQuantile.estimation(); + } + + @Override + public double higherQuantileLatency() { + return higherQuantile.estimation(); + } + + @Override + public int pending() { + return pendingRequests + pendingStreams; + } + + @Override + public double weightedAvailability() { + if (Clock.now() - stamp > tau) { + updateAvailability(1.0); + } + return availabilityPercentage.value(); + } + + @Override + public double predictedLatency() { + final long now = Clock.now(); + final long elapsed; + + synchronized (this) { + elapsed = Math.max(now - stamp, 1L); + } + + final double latency; + final double prediction = median.estimation(); + + final int pending = this.pending(); + if (prediction == 0.0) { + if (pending == 0) { + latency = 0.0; // first request + } else { + // subsequent requests while we don't have any history + latency = STARTUP_PENALTY + pending; + } + } else if (pending == 0 && elapsed > inactivityFactor * interArrivalTime.value()) { + // if we did't see any data for a while, we decay the prediction by inserting + // artificial 0.0 into the median + median.insert(0.0); + latency = median.estimation(); + } else { + final double predicted = prediction * pending; + final double instant = instantaneous(now, pending); + + if (predicted < instant) { // NB: (0.0 < 0.0) == false + latency = instant / pending; // NB: pending never equal 0 here + } else { + // we are under the predictions + latency = prediction; + } + } + + return latency; + } + + long instantaneous(long now, int pending) { + return duration + (now - stamp0) * pending; + } + + void startStream() { + PENDING_STREAMS.incrementAndGet(this); + } + + void stopStream() { + PENDING_STREAMS.decrementAndGet(this); + } + + synchronized long startRequest() { + final long now = Clock.now(); + final int pendingRequests = this.pendingRequests; + + interArrivalTime.insert(now - stamp); + duration += Math.max(0, now - stamp0) * pendingRequests; + PENDING_REQUESTS.lazySet(this, pendingRequests + 1); + stamp = now; + stamp0 = now; + + return now; + } + + synchronized long stopRequest(long timestamp) { + final long now = Clock.now(); + final int pendingRequests = this.pendingRequests; + + duration += Math.max(0, now - stamp0) * pendingRequests - (now - timestamp); + PENDING_REQUESTS.lazySet(this, pendingRequests - 1); + stamp0 = now; + + return now; + } + + synchronized void record(double roundTripTime) { + median.insert(roundTripTime); + lowerQuantile.insert(roundTripTime); + higherQuantile.insert(roundTripTime); + } + + void updateAvailability(double value) { + availabilityPercentage.insert(value); + if (value == 0.0d) { + synchronized (this) { + errorStamp = Clock.now(); + } + } + } + + @Override + public String toString() { + return "Stats{" + + "lowerQuantile=" + + lowerQuantile.estimation() + + ", higherQuantile=" + + higherQuantile.estimation() + + ", inactivityFactor=" + + inactivityFactor + + ", tau=" + + tau + + ", errorPercentage=" + + availabilityPercentage.value() + + ", pending=" + + pendingRequests + + ", errorStamp=" + + errorStamp + + ", stamp=" + + stamp + + ", stamp0=" + + stamp0 + + ", duration=" + + duration + + ", median=" + + median.estimation() + + ", interArrivalTime=" + + interArrivalTime.value() + + ", pendingStreams=" + + pendingStreams + + ", availability=" + + availabilityPercentage.value() + + '}'; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/ClientLoadbalanceStrategy.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/ClientLoadbalanceStrategy.java new file mode 100644 index 000000000..528f4f896 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/ClientLoadbalanceStrategy.java @@ -0,0 +1,40 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.core.RSocketConnector; +import io.rsocket.plugins.InterceptorRegistry; + +/** + * A {@link LoadbalanceStrategy} with an interest in configuring the {@link RSocketConnector} for + * connecting to load-balance targets in order to hook into request lifecycle and track usage + * statistics. + * + *

Currently this callback interface is supported for strategies configured in {@link + * LoadbalanceRSocketClient}. + * + * @since 1.1 + */ +public interface ClientLoadbalanceStrategy extends LoadbalanceStrategy { + + /** + * Initialize the connector, for example using the {@link InterceptorRegistry}, to intercept + * requests. + * + * @param connector the connector to configure + */ + void initialize(RSocketConnector connector); +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/Ewma.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/Ewma.java new file mode 100644 index 000000000..0f87f6510 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/Ewma.java @@ -0,0 +1,71 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.loadbalance; + +import io.rsocket.util.Clock; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; + +/** + * Compute the exponential weighted moving average of a series of values. The time at which you + * insert the value into `Ewma` is used to compute a weight (recent points are weighted higher). The + * parameter for defining the convergence speed (like most decay process) is the half-life. + * + *

e.g. with a half-life of 10 unit, if you insert 100 at t=0 and 200 at t=10 the ewma will be + * equal to (200 - 100)/2 = 150 (half of the distance between the new and the old value) + */ +class Ewma { + + final long tau; + + volatile long stamp; + static final AtomicLongFieldUpdater STAMP = + AtomicLongFieldUpdater.newUpdater(Ewma.class, "stamp"); + volatile double ewma; + + public Ewma(long halfLife, TimeUnit unit, double initialValue) { + this.tau = Clock.unit().convert((long) (halfLife / Math.log(2)), unit); + + this.ewma = initialValue; + + STAMP.lazySet(this, 0L); + } + + public synchronized void insert(double x) { + final long now = Clock.now(); + final double elapsed = Math.max(0, now - stamp); + + STAMP.lazySet(this, now); + + double w = Math.exp(-elapsed / tau); + ewma = w * ewma + (1.0 - w) * x; + } + + public synchronized void reset(double value) { + stamp = 0L; + ewma = value; + } + + public double value() { + return ewma; + } + + @Override + public String toString() { + return "Ewma(value=" + ewma + ", age=" + (Clock.now() - stamp) + ")"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/FluxDeferredResolution.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/FluxDeferredResolution.java new file mode 100644 index 000000000..6c2b9c3ea --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/FluxDeferredResolution.java @@ -0,0 +1,228 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.function.BiConsumer; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +abstract class FluxDeferredResolution extends Flux + implements CoreSubscriber, Subscription, BiConsumer, Scannable { + + final ResolvingOperator parent; + final INPUT fluxOrPayload; + final FrameType requestType; + + volatile long requested; + + @SuppressWarnings("rawtypes") + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(FluxDeferredResolution.class, "requested"); + + static final long STATE_UNSUBSCRIBED = -1; + static final long STATE_SUBSCRIBER_SET = 0; + static final long STATE_SUBSCRIBED = -2; + static final long STATE_TERMINATED = Long.MIN_VALUE; + + Subscription s; + CoreSubscriber actual; + boolean done; + + FluxDeferredResolution(ResolvingOperator parent, INPUT fluxOrPayload, FrameType requestType) { + this.parent = parent; + this.fluxOrPayload = fluxOrPayload; + this.requestType = requestType; + + REQUESTED.lazySet(this, STATE_UNSUBSCRIBED); + } + + @Override + public final void subscribe(CoreSubscriber actual) { + if (this.requested == STATE_UNSUBSCRIBED + && REQUESTED.compareAndSet(this, STATE_UNSUBSCRIBED, STATE_SUBSCRIBER_SET)) { + + actual.onSubscribe(this); + + if (this.requested == STATE_TERMINATED) { + return; + } + + this.actual = actual; + this.parent.observe(this); + } else { + Operators.error(actual, new IllegalStateException("Only a single Subscriber allowed")); + } + } + + @Override + public final Context currentContext() { + return this.actual.currentContext(); + } + + @Nullable + @Override + public final Object scanUnsafe(Attr key) { + long state = this.requested; + + if (key == Attr.PARENT) { + return this.s; + } + if (key == Attr.ACTUAL) { + return this.parent; + } + if (key == Attr.TERMINATED) { + return this.done; + } + if (key == Attr.CANCELLED) { + return state == STATE_TERMINATED; + } + + return null; + } + + @Override + public final void onSubscribe(Subscription s) { + final long state = this.requested; + Subscription a = this.s; + if (state == STATE_TERMINATED) { + s.cancel(); + return; + } + if (a != null) { + s.cancel(); + return; + } + + long r; + long accumulated = 0; + for (; ; ) { + r = this.requested; + + if (r == STATE_TERMINATED || r == STATE_SUBSCRIBED) { + s.cancel(); + return; + } + + this.s = s; + + long toRequest = r - accumulated; + if (toRequest > 0) { // if there is something, + s.request(toRequest); // then we do a request on the given subscription + } + accumulated = r; + + if (REQUESTED.compareAndSet(this, r, STATE_SUBSCRIBED)) { + return; + } + } + } + + @Override + public final void onNext(Payload payload) { + this.actual.onNext(payload); + } + + @Override + public final void onError(Throwable t) { + if (this.done) { + Operators.onErrorDropped(t, this.actual.currentContext()); + return; + } + + this.done = true; + this.actual.onError(t); + } + + @Override + public final void onComplete() { + if (this.done) { + return; + } + + this.done = true; + this.actual.onComplete(); + } + + @Override + public final void request(long n) { + if (Operators.validate(n)) { + long r = this.requested; // volatile read beforehand + + if (r > STATE_SUBSCRIBED) { // works only in case onSubscribe has not happened + long u; + for (; ; ) { // normal CAS loop with overflow protection + if (r == Long.MAX_VALUE) { + // if r == Long.MAX_VALUE then we dont care and we can loose this + // request just in case of racing + return; + } + u = Operators.addCap(r, n); + if (REQUESTED.compareAndSet(this, r, u)) { + // Means increment happened before onSubscribe + return; + } else { + // Means increment happened after onSubscribe + + // update new state to see what exactly happened (onSubscribe |cancel | requestN) + r = this.requested; + + // check state (expect -1 | -2 to exit, otherwise repeat) + if (r < 0) { + break; + } + } + } + } + + if (r == STATE_TERMINATED) { // if canceled, just exit + return; + } + + // if onSubscribe -> subscription exists (and we sure of that because volatile read + // after volatile write) so we can execute requestN on the subscription + this.s.request(n); + } + } + + public final void cancel() { + long state = REQUESTED.getAndSet(this, STATE_TERMINATED); + if (state == STATE_TERMINATED) { + return; + } + + if (state == STATE_SUBSCRIBED) { + this.s.cancel(); + } else { + this.parent.remove(this); + if (requestType == FrameType.REQUEST_STREAM) { + ReferenceCountUtil.safeRelease(this.fluxOrPayload); + } + } + } + + boolean isTerminated() { + return this.requested == STATE_TERMINATED; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/FrugalQuantile.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/FrugalQuantile.java new file mode 100644 index 000000000..cdbdc19b3 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/FrugalQuantile.java @@ -0,0 +1,133 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.loadbalance; + +import java.util.SplittableRandom; + +/** + * Reference: Ma, Qiang, S. Muthukrishnan, and Mark Sandler. "Frugal Streaming for Estimating + * Quantiles." Space-Efficient Data Structures, Streams, and Algorithms. Springer Berlin Heidelberg, + * 2013. 77-96. + * + *

More info: http://blog.aggregateknowledge.com/2013/09/16/sketch-of-the-day-frugal-streaming/ + */ +class FrugalQuantile implements Quantile { + final double increment; + final SplittableRandom rnd; + + int step; + int sign; + double quantile; + + volatile double estimate; + + public FrugalQuantile(double quantile, double increment) { + this.increment = increment; + this.quantile = quantile; + this.estimate = 0.0; + this.step = 1; + this.sign = 0; + this.rnd = new SplittableRandom(System.nanoTime()); + } + + public FrugalQuantile(double quantile) { + this(quantile, 1.0); + } + + public synchronized void reset(double quantile) { + this.quantile = quantile; + this.estimate = 0.0; + this.step = 1; + this.sign = 0; + } + + public double estimation() { + return estimate; + } + + @Override + public synchronized void insert(double x) { + if (sign == 0) { + estimate = x; + sign = 1; + } else { + final double v = rnd.nextDouble(); + final double estimate = this.estimate; + + if (x > estimate && v > (1 - quantile)) { + higher(x); + } else if (x < estimate && v > quantile) { + lower(x); + } + } + } + + private void higher(double x) { + double estimate = this.estimate; + + step += sign * increment; + + if (step > 0) { + estimate += step; + } else { + estimate += 1; + } + + if (estimate > x) { + step += (x - estimate); + estimate = x; + } + + if (sign < 0) { + step = 1; + } + + sign = 1; + + this.estimate = estimate; + } + + private void lower(double x) { + double estimate = this.estimate; + + step -= sign * increment; + + if (step > 0) { + estimate -= step; + } else { + estimate--; + } + + if (estimate < x) { + step += (estimate - x); + estimate = x; + } + + if (sign > 0) { + step = 1; + } + + sign = -1; + + this.estimate = estimate; + } + + @Override + public String toString() { + return "FrugalQuantile(q=" + quantile + ", v=" + estimate + ")"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/Int2LongHashMap.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/Int2LongHashMap.java new file mode 100644 index 000000000..eebf82fe9 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/Int2LongHashMap.java @@ -0,0 +1,1005 @@ +/* + * Copyright 2014-2020 Real Logic Limited. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import java.io.Serializable; +import java.util.AbstractCollection; +import java.util.AbstractSet; +import java.util.Arrays; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.function.Function; +import java.util.function.IntToLongFunction; +import reactor.util.annotation.Nullable; + +/** A open addressing with linear probing hash map specialised for primitive key and value pairs. */ +class Int2LongHashMap implements Map, Serializable { + static final float DEFAULT_LOAD_FACTOR = 0.55f; + static final int MIN_CAPACITY = 8; + private static final long serialVersionUID = -690554872053575793L; + + private final float loadFactor; + private final long missingValue; + private int resizeThreshold; + private int size = 0; + private final boolean shouldAvoidAllocation; + + private long[] entries; + private KeySet keySet; + private ValueCollection values; + private EntrySet entrySet; + + /** @param missingValue for the map that represents null. */ + public Int2LongHashMap(final long missingValue) { + this(MIN_CAPACITY, DEFAULT_LOAD_FACTOR, missingValue); + } + + /** + * @param initialCapacity for the map to override {@link #MIN_CAPACITY} + * @param loadFactor for the map to override {@link #DEFAULT_LOAD_FACTOR}. + * @param missingValue for the map that represents null. + */ + public Int2LongHashMap( + final int initialCapacity, final float loadFactor, final long missingValue) { + this(initialCapacity, loadFactor, missingValue, true); + } + + /** + * @param initialCapacity for the map to override {@link #MIN_CAPACITY} + * @param loadFactor for the map to override {@link #DEFAULT_LOAD_FACTOR}. + * @param missingValue for the map that represents null. + * @param shouldAvoidAllocation should allocation be avoided by caching iterators and map entries. + */ + public Int2LongHashMap( + final int initialCapacity, + final float loadFactor, + final long missingValue, + final boolean shouldAvoidAllocation) { + validateLoadFactor(loadFactor); + + this.loadFactor = loadFactor; + this.missingValue = missingValue; + this.shouldAvoidAllocation = shouldAvoidAllocation; + + capacity(findNextPositivePowerOfTwo(Math.max(MIN_CAPACITY, initialCapacity))); + } + + /** + * The value to be used as a null marker in the map. + * + * @return value to be used as a null marker in the map. + */ + public long missingValue() { + return missingValue; + } + + /** + * Get the load factor applied for resize operations. + * + * @return the load factor applied for resize operations. + */ + public float loadFactor() { + return loadFactor; + } + + /** + * Get the total capacity for the map to which the load factor will be a fraction of. + * + * @return the total capacity for the map. + */ + public int capacity() { + return entries.length >> 1; + } + + /** + * Get the actual threshold which when reached the map will resize. This is a function of the + * current capacity and load factor. + * + * @return the threshold when the map will resize. + */ + public int resizeThreshold() { + return resizeThreshold; + } + + /** {@inheritDoc} */ + public int size() { + return size; + } + + /** {@inheritDoc} */ + public boolean isEmpty() { + return size == 0; + } + + /** + * Get a value using provided key avoiding boxing. + * + * @param key lookup key. + * @return value associated with the key or {@link #missingValue()} if key is not found in the + * map. + */ + public long get(final int key) { + final int mask = entries.length - 1; + int index = evenHash(key, mask); + + long value = missingValue; + while (entries[index + 1] != missingValue) { + if (entries[index] == key) { + value = entries[index + 1]; + break; + } + + index = next(index, mask); + } + + return value; + } + + /** + * Put a key value pair in the map. + * + * @param key lookup key + * @param value new value, must not be {@link #missingValue()} + * @return previous value associated with the key, or {@link #missingValue()} if none found + * @throws IllegalArgumentException if value is {@link #missingValue()} + */ + public long put(final int key, final long value) { + if (value == missingValue) { + throw new IllegalArgumentException("cannot accept missingValue"); + } + + final int mask = entries.length - 1; + int index = evenHash(key, mask); + long oldValue = missingValue; + + while (entries[index + 1] != missingValue) { + if (entries[index] == key) { + oldValue = entries[index + 1]; + break; + } + + index = next(index, mask); + } + + if (oldValue == missingValue) { + ++size; + entries[index] = key; + } + + entries[index + 1] = value; + + increaseCapacity(); + + return oldValue; + } + + private void increaseCapacity() { + if (size > resizeThreshold) { + // entries.length = 2 * capacity + final int newCapacity = entries.length; + rehash(newCapacity); + } + } + + private void rehash(final int newCapacity) { + final long[] oldEntries = entries; + final int length = entries.length; + + capacity(newCapacity); + + final long[] newEntries = entries; + final int mask = entries.length - 1; + + for (int keyIndex = 0; keyIndex < length; keyIndex += 2) { + final long value = oldEntries[keyIndex + 1]; + if (value != missingValue) { + final int key = (int) oldEntries[keyIndex]; + int index = evenHash(key, mask); + + while (newEntries[index + 1] != missingValue) { + index = next(index, mask); + } + + newEntries[index] = key; + newEntries[index + 1] = value; + } + } + } + + /** + * Int primitive specialised containsKey. + * + * @param key the key to check. + * @return true if the map contains key as a key, false otherwise. + */ + public boolean containsKey(final int key) { + return get(key) != missingValue; + } + + /** + * Does the map contain the value. + * + * @param value to be tested against contained values. + * @return true if contained otherwise value. + */ + public boolean containsValue(final long value) { + boolean found = false; + if (value != missingValue) { + final int length = entries.length; + int remaining = size; + + for (int valueIndex = 1; remaining > 0 && valueIndex < length; valueIndex += 2) { + if (missingValue != entries[valueIndex]) { + if (value == entries[valueIndex]) { + found = true; + break; + } + --remaining; + } + } + } + + return found; + } + + /** {@inheritDoc} */ + public void clear() { + if (size > 0) { + Arrays.fill(entries, missingValue); + size = 0; + } + } + + /** + * Compact the backing arrays by rehashing with a capacity just larger than current size and + * giving consideration to the load factor. + */ + public void compact() { + final int idealCapacity = (int) Math.round(size() * (1.0d / loadFactor)); + rehash(findNextPositivePowerOfTwo(Math.max(MIN_CAPACITY, idealCapacity))); + } + + /** + * Primitive specialised version of {@link #computeIfAbsent(Object, Function)} + * + * @param key to search on. + * @param mappingFunction to provide a value if the get returns null. + * @return the value if found otherwise the missing value. + */ + public long computeIfAbsent(final int key, final IntToLongFunction mappingFunction) { + long value = get(key); + if (value == missingValue) { + value = mappingFunction.applyAsLong(key); + if (value != missingValue) { + put(key, value); + } + } + + return value; + } + + // ---------------- Boxed Versions Below ---------------- + + /** {@inheritDoc} */ + @Nullable + public Long get(final Object key) { + return valOrNull(get((int) key)); + } + + /** {@inheritDoc} */ + public Long put(final Integer key, final Long value) { + return valOrNull(put((int) key, (long) value)); + } + + /** {@inheritDoc} */ + public boolean containsKey(final Object key) { + return containsKey((int) key); + } + + /** {@inheritDoc} */ + public boolean containsValue(final Object value) { + return containsValue((long) value); + } + + /** {@inheritDoc} */ + public void putAll(final Map map) { + for (final Map.Entry entry : map.entrySet()) { + put(entry.getKey(), entry.getValue()); + } + } + + /** {@inheritDoc} */ + public KeySet keySet() { + if (null == keySet) { + keySet = new KeySet(); + } + + return keySet; + } + + /** {@inheritDoc} */ + public ValueCollection values() { + if (null == values) { + values = new ValueCollection(); + } + + return values; + } + + /** {@inheritDoc} */ + public EntrySet entrySet() { + if (null == entrySet) { + entrySet = new EntrySet(); + } + + return entrySet; + } + + /** {@inheritDoc} */ + @Nullable + public Long remove(final Object key) { + return valOrNull(remove((int) key)); + } + + /** + * Remove value from the map using given key avoiding boxing. + * + * @param key whose mapping is to be removed from the map. + * @return removed value or {@link #missingValue()} if key was not found in the map. + */ + public long remove(final int key) { + final int mask = entries.length - 1; + int keyIndex = evenHash(key, mask); + + long oldValue = missingValue; + while (entries[keyIndex + 1] != missingValue) { + if (entries[keyIndex] == key) { + oldValue = entries[keyIndex + 1]; + entries[keyIndex + 1] = missingValue; + size--; + + compactChain(keyIndex); + + break; + } + + keyIndex = next(keyIndex, mask); + } + + return oldValue; + } + + @SuppressWarnings("FinalParameters") + private void compactChain(int deleteKeyIndex) { + final int mask = entries.length - 1; + int keyIndex = deleteKeyIndex; + + while (true) { + keyIndex = next(keyIndex, mask); + if (entries[keyIndex + 1] == missingValue) { + break; + } + + final int hash = evenHash((int) entries[keyIndex], mask); + + if ((keyIndex < hash && (hash <= deleteKeyIndex || deleteKeyIndex <= keyIndex)) + || (hash <= deleteKeyIndex && deleteKeyIndex <= keyIndex)) { + entries[deleteKeyIndex] = entries[keyIndex]; + entries[deleteKeyIndex + 1] = entries[keyIndex + 1]; + + entries[keyIndex + 1] = missingValue; + deleteKeyIndex = keyIndex; + } + } + } + + /** + * Get the minimum value stored in the map. If the map is empty then it will return {@link + * #missingValue()} + * + * @return the minimum value stored in the map. + */ + public long minValue() { + final long missingValue = this.missingValue; + long min = size == 0 ? missingValue : Long.MAX_VALUE; + final int length = entries.length; + + for (int valueIndex = 1; valueIndex < length; valueIndex += 2) { + final long value = entries[valueIndex]; + if (value != missingValue) { + min = Math.min(min, value); + } + } + + return min; + } + + /** + * Get the maximum value stored in the map. If the map is empty then it will return {@link + * #missingValue()} + * + * @return the maximum value stored in the map. + */ + public long maxValue() { + final long missingValue = this.missingValue; + long max = size == 0 ? missingValue : Long.MIN_VALUE; + final int length = entries.length; + + for (int valueIndex = 1; valueIndex < length; valueIndex += 2) { + final long value = entries[valueIndex]; + if (value != missingValue) { + max = Math.max(max, value); + } + } + + return max; + } + + /** {@inheritDoc} */ + public String toString() { + if (isEmpty()) { + return "{}"; + } + + final EntryIterator entryIterator = new EntryIterator(); + entryIterator.reset(); + + final StringBuilder sb = new StringBuilder().append('{'); + while (true) { + entryIterator.next(); + sb.append(entryIterator.getIntKey()).append('=').append(entryIterator.getLongValue()); + if (!entryIterator.hasNext()) { + return sb.append('}').toString(); + } + sb.append(',').append(' '); + } + } + + /** + * Primitive specialised version of {@link #replace(Object, Object)} + * + * @param key key with which the specified value is associated + * @param value value to be associated with the specified key + * @return the previous value associated with the specified key, or {@link #missingValue()} if + * there was no mapping for the key. + */ + public long replace(final int key, final long value) { + long currentValue = get(key); + if (currentValue != missingValue) { + currentValue = put(key, value); + } + + return currentValue; + } + + /** + * Primitive specialised version of {@link #replace(Object, Object, Object)} + * + * @param key key with which the specified value is associated + * @param oldValue value expected to be associated with the specified key + * @param newValue value to be associated with the specified key + * @return {@code true} if the value was replaced + */ + public boolean replace(final int key, final long oldValue, final long newValue) { + final long curValue = get(key); + if (curValue != oldValue || curValue == missingValue) { + return false; + } + + put(key, newValue); + + return true; + } + + /** {@inheritDoc} */ + public boolean equals(final Object o) { + if (this == o) { + return true; + } + + if (!(o instanceof Map)) { + return false; + } + + final Map that = (Map) o; + + return size == that.size() && entrySet().equals(that.entrySet()); + } + + public int hashCode() { + return entrySet().hashCode(); + } + + private static int next(final int index, final int mask) { + return (index + 2) & mask; + } + + private void capacity(final int newCapacity) { + final int entriesLength = newCapacity * 2; + if (entriesLength < 0) { + throw new IllegalStateException("max capacity reached at size=" + size); + } + + /*@DoNotSub*/ resizeThreshold = (int) (newCapacity * loadFactor); + entries = new long[entriesLength]; + Arrays.fill(entries, missingValue); + } + + @Nullable + private Long valOrNull(final long value) { + return value == missingValue ? null : value; + } + + // ---------------- Utility Classes ---------------- + + /** Base iterator implementation. */ + abstract class AbstractIterator implements Serializable { + private static final long serialVersionUID = 5262459454112462433L; + /** Is current position valid. */ + protected boolean isPositionValid = false; + + private int remaining; + private int positionCounter; + private int stopCounter; + + final void reset() { + isPositionValid = false; + remaining = Int2LongHashMap.this.size; + final long missingValue = Int2LongHashMap.this.missingValue; + final long[] entries = Int2LongHashMap.this.entries; + final int capacity = entries.length; + + int keyIndex = capacity; + if (entries[capacity - 1] != missingValue) { + for (int i = 1; i < capacity; i += 2) { + if (entries[i] == missingValue) { + keyIndex = i - 1; + break; + } + } + } + + stopCounter = keyIndex; + positionCounter = keyIndex + capacity; + } + + /** + * Returns position of the key of the current entry. + * + * @return key position. + */ + protected final int keyPosition() { + return positionCounter & entries.length - 1; + } + + /** + * Number of remaining elements. + * + * @return number of remaining elements. + */ + public int remaining() { + return remaining; + } + + /** + * Check if there are more elements remaining. + * + * @return {@code true} if {@code remaining > 0}. + */ + public boolean hasNext() { + return remaining > 0; + } + + /** + * Advance to the next entry. + * + * @throws NoSuchElementException if no more entries available. + */ + protected final void findNext() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + final long[] entries = Int2LongHashMap.this.entries; + final long missingValue = Int2LongHashMap.this.missingValue; + final int mask = entries.length - 1; + + for (int keyIndex = positionCounter - 2; keyIndex >= stopCounter; keyIndex -= 2) { + final int index = keyIndex & mask; + if (entries[index + 1] != missingValue) { + isPositionValid = true; + positionCounter = keyIndex; + --remaining; + return; + } + } + + isPositionValid = false; + throw new IllegalStateException(); + } + + /** {@inheritDoc} */ + public void remove() { + if (isPositionValid) { + final int position = keyPosition(); + entries[position + 1] = missingValue; + --size; + + compactChain(position); + + isPositionValid = false; + } else { + throw new IllegalStateException(); + } + } + } + + /** Iterator over keys which supports access to unboxed keys via {@link #nextValue()}. */ + public final class KeyIterator extends AbstractIterator + implements Iterator, Serializable { + private static final long serialVersionUID = 9151493609653852972L; + + public Integer next() { + return nextValue(); + } + + /** + * Return next key. + * + * @return next key. + */ + public int nextValue() { + findNext(); + return (int) entries[keyPosition()]; + } + } + + /** Iterator over values which supports access to unboxed values. */ + public final class ValueIterator extends AbstractIterator + implements Iterator, Serializable { + private static final long serialVersionUID = -5670291734793552927L; + + public Long next() { + return nextValue(); + } + + /** + * Return next value. + * + * @return next value. + */ + public long nextValue() { + findNext(); + return entries[keyPosition() + 1]; + } + } + + /** Iterator over entries which supports access to unboxed keys and values. */ + public final class EntryIterator extends AbstractIterator + implements Iterator>, Entry, Serializable { + private static final long serialVersionUID = 1744408438593481051L; + + public Integer getKey() { + return getIntKey(); + } + + /** + * Returns the key of the current entry. + * + * @return the key. + */ + public int getIntKey() { + return (int) entries[keyPosition()]; + } + + public Long getValue() { + return getLongValue(); + } + + /** + * Returns the value of the current entry. + * + * @return the value. + */ + public long getLongValue() { + return entries[keyPosition() + 1]; + } + + public Long setValue(final Long value) { + return setValue(value.longValue()); + } + + /** + * Sets the value of the current entry. + * + * @param value to be set. + * @return previous value of the entry. + */ + public long setValue(final long value) { + if (!isPositionValid) { + throw new IllegalStateException(); + } + + if (missingValue == value) { + throw new IllegalArgumentException(); + } + + final int keyPosition = keyPosition(); + final long prevValue = entries[keyPosition + 1]; + entries[keyPosition + 1] = value; + return prevValue; + } + + public Entry next() { + findNext(); + + if (shouldAvoidAllocation) { + return this; + } + + return allocateDuplicateEntry(); + } + + private Entry allocateDuplicateEntry() { + return new MapEntry(getIntKey(), getLongValue()); + } + + /** {@inheritDoc} */ + public int hashCode() { + return Integer.hashCode(getIntKey()) ^ Long.hashCode(getLongValue()); + } + + /** {@inheritDoc} */ + public boolean equals(final Object o) { + if (this == o) { + return true; + } + + if (!(o instanceof Entry)) { + return false; + } + + final Entry that = (Entry) o; + + return Objects.equals(getKey(), that.getKey()) && Objects.equals(getValue(), that.getValue()); + } + + /** An {@link java.util.Map.Entry} implementation. */ + public final class MapEntry implements Entry { + private final int k; + private final long v; + + /** + * Constructs entry with given key and value. + * + * @param k key. + * @param v value. + */ + public MapEntry(final int k, final long v) { + this.k = k; + this.v = v; + } + + public Integer getKey() { + return k; + } + + public Long getValue() { + return v; + } + + public Long setValue(final Long value) { + return Int2LongHashMap.this.put(k, value.longValue()); + } + + public int hashCode() { + return Integer.hashCode(getIntKey()) ^ Long.hashCode(getLongValue()); + } + + public boolean equals(final Object o) { + if (!(o instanceof Map.Entry)) { + return false; + } + + final Entry e = (Entry) o; + + return (e.getKey() != null && e.getValue() != null) + && (e.getKey().equals(k) && e.getValue().equals(v)); + } + + public String toString() { + return k + "=" + v; + } + } + } + + /** Set of keys which supports optional cached iterators to avoid allocation. */ + public final class KeySet extends AbstractSet implements Serializable { + private static final long serialVersionUID = -7645453993079742625L; + private final KeyIterator keyIterator = shouldAvoidAllocation ? new KeyIterator() : null; + + /** {@inheritDoc} */ + public KeyIterator iterator() { + KeyIterator keyIterator = this.keyIterator; + if (null == keyIterator) { + keyIterator = new KeyIterator(); + } + + keyIterator.reset(); + + return keyIterator; + } + + /** {@inheritDoc} */ + public int size() { + return Int2LongHashMap.this.size(); + } + + /** {@inheritDoc} */ + public boolean isEmpty() { + return Int2LongHashMap.this.isEmpty(); + } + + /** {@inheritDoc} */ + public void clear() { + Int2LongHashMap.this.clear(); + } + + /** {@inheritDoc} */ + public boolean contains(final Object o) { + return contains((int) o); + } + + /** + * Checks if key is contained in the map without boxing. + * + * @param key to check. + * @return {@code true} if key is contained in this map. + */ + public boolean contains(final int key) { + return containsKey(key); + } + } + + /** Collection of values which supports optionally cached iterators to avoid allocation. */ + public final class ValueCollection extends AbstractCollection implements Serializable { + private static final long serialVersionUID = -8925598924781601919L; + private final ValueIterator valueIterator = shouldAvoidAllocation ? new ValueIterator() : null; + + /** {@inheritDoc} */ + public ValueIterator iterator() { + ValueIterator valueIterator = this.valueIterator; + if (null == valueIterator) { + valueIterator = new ValueIterator(); + } + + valueIterator.reset(); + + return valueIterator; + } + + /** {@inheritDoc} */ + public int size() { + return Int2LongHashMap.this.size(); + } + + /** {@inheritDoc} */ + public boolean contains(final Object o) { + return contains((long) o); + } + + /** + * Checks if the value is contained in the map. + * + * @param value to be checked. + * @return {@code true} if value is contained in this map. + */ + public boolean contains(final long value) { + return containsValue(value); + } + } + + /** Set of entries which supports optionally cached iterators to avoid allocation. */ + public final class EntrySet extends AbstractSet> + implements Serializable { + private static final long serialVersionUID = 63641283589916174L; + private final EntryIterator entryIterator = shouldAvoidAllocation ? new EntryIterator() : null; + + /** {@inheritDoc} */ + public EntryIterator iterator() { + EntryIterator entryIterator = this.entryIterator; + if (null == entryIterator) { + entryIterator = new EntryIterator(); + } + + entryIterator.reset(); + + return entryIterator; + } + + /** {@inheritDoc} */ + public int size() { + return Int2LongHashMap.this.size(); + } + + /** {@inheritDoc} */ + public boolean isEmpty() { + return Int2LongHashMap.this.isEmpty(); + } + + /** {@inheritDoc} */ + public void clear() { + Int2LongHashMap.this.clear(); + } + + /** {@inheritDoc} */ + public boolean contains(final Object o) { + if (!(o instanceof Entry)) { + return false; + } + final Entry entry = (Entry) o; + final Long value = get(entry.getKey()); + + return value != null && value.equals(entry.getValue()); + } + + /** {@inheritDoc} */ + public Object[] toArray() { + return toArray(new Object[size()]); + } + + /** {@inheritDoc} */ + @SuppressWarnings("unchecked") + public T[] toArray(final T[] a) { + final T[] array = + a.length >= size + ? a + : (T[]) java.lang.reflect.Array.newInstance(a.getClass().getComponentType(), size); + final EntryIterator it = iterator(); + + for (int i = 0; i < array.length; i++) { + if (it.hasNext()) { + it.next(); + array[i] = (T) it.allocateDuplicateEntry(); + } else { + array[i] = null; + break; + } + } + + return array; + } + } + + private static int evenHash(final int value, final int mask) { + final int hash = (value << 1) - (value << 8); + + return hash & mask; + } + + private static void validateLoadFactor(final float loadFactor) { + if (loadFactor < 0.1f || loadFactor > 0.9f) { + throw new IllegalArgumentException( + "load factor must be in the range of 0.1 to 0.9: " + loadFactor); + } + } + + private static int findNextPositivePowerOfTwo(final int value) { + return 1 << (Integer.SIZE - Integer.numberOfLeadingZeros(value - 1)); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceRSocketClient.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceRSocketClient.java new file mode 100644 index 000000000..d59cbb86e --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceRSocketClient.java @@ -0,0 +1,195 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketClient; +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.ClientTransport; +import java.util.List; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; + +/** + * An implementation of {@link RSocketClient} backed by a pool of {@code RSocket} instances and + * using a {@link LoadbalanceStrategy} to select the {@code RSocket} to use for a given request. + * + * @since 1.1 + */ +public class LoadbalanceRSocketClient implements RSocketClient { + + private final RSocketPool rSocketPool; + + private LoadbalanceRSocketClient(RSocketPool rSocketPool) { + this.rSocketPool = rSocketPool; + } + + @Override + public Mono onClose() { + return rSocketPool.onClose(); + } + + @Override + public boolean connect() { + return rSocketPool.connect(); + } + + /** Return {@code Mono} that selects an RSocket from the underlying pool. */ + @Override + public Mono source() { + return Mono.fromSupplier(rSocketPool::select); + } + + @Override + public Mono fireAndForget(Mono payloadMono) { + return payloadMono.flatMap(p -> rSocketPool.select().fireAndForget(p)); + } + + @Override + public Mono requestResponse(Mono payloadMono) { + return payloadMono.flatMap(p -> rSocketPool.select().requestResponse(p)); + } + + @Override + public Flux requestStream(Mono payloadMono) { + return payloadMono.flatMapMany(p -> rSocketPool.select().requestStream(p)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return source().flatMapMany(rSocket -> rSocket.requestChannel(payloads)); + } + + @Override + public Mono metadataPush(Mono payloadMono) { + return payloadMono.flatMap(p -> rSocketPool.select().metadataPush(p)); + } + + @Override + public void dispose() { + rSocketPool.dispose(); + } + + /** + * Shortcut to create an {@link LoadbalanceRSocketClient} with round-robin load balancing. + * Effectively a shortcut for: + * + *

+   * LoadbalanceRSocketClient.builder(targetPublisher)
+   *    .connector(RSocketConnector.create())
+   *    .build();
+   * 
+ * + * @param connector a "template" for connecting to load balance targets + * @param targetPublisher refreshes the list of load balance targets periodically + * @return the created client instance + */ + public static LoadbalanceRSocketClient create( + RSocketConnector connector, Publisher> targetPublisher) { + return builder(targetPublisher).connector(connector).build(); + } + + /** + * Return a builder for a {@link LoadbalanceRSocketClient}. + * + * @param targetPublisher refreshes the list of load balance targets periodically + * @return the created builder + */ + public static Builder builder(Publisher> targetPublisher) { + return new Builder(targetPublisher); + } + + /** Builder for creating an {@link LoadbalanceRSocketClient}. */ + public static class Builder { + + private final Publisher> targetPublisher; + + @Nullable private RSocketConnector connector; + + @Nullable LoadbalanceStrategy loadbalanceStrategy; + + Builder(Publisher> targetPublisher) { + this.targetPublisher = targetPublisher; + } + + /** + * Configure the "template" connector to use for connecting to load balance targets. To + * establish a connection, the {@link LoadbalanceTarget#getTransport() ClientTransport} + * contained in each target is passed to the connector's {@link + * RSocketConnector#connect(ClientTransport) connect} method and thus the same connector with + * the same settings applies to all targets. + * + *

By default this is initialized with {@link RSocketConnector#create()}. + * + * @param connector the connector to use as a template + */ + public Builder connector(RSocketConnector connector) { + this.connector = connector; + return this; + } + + /** + * Configure {@link RoundRobinLoadbalanceStrategy} as the strategy to use to select targets. + * + *

This is the strategy used by default. + */ + public Builder roundRobinLoadbalanceStrategy() { + this.loadbalanceStrategy = new RoundRobinLoadbalanceStrategy(); + return this; + } + + /** + * Configure {@link WeightedLoadbalanceStrategy} as the strategy to use to select targets. + * + *

By default, {@link RoundRobinLoadbalanceStrategy} is used. + */ + public Builder weightedLoadbalanceStrategy() { + this.loadbalanceStrategy = WeightedLoadbalanceStrategy.create(); + return this; + } + + /** + * Configure the {@link LoadbalanceStrategy} to use. + * + *

By default, {@link RoundRobinLoadbalanceStrategy} is used. + */ + public Builder loadbalanceStrategy(LoadbalanceStrategy strategy) { + this.loadbalanceStrategy = strategy; + return this; + } + + /** Build the {@link LoadbalanceRSocketClient} instance. */ + public LoadbalanceRSocketClient build() { + final RSocketConnector connector = + (this.connector != null ? this.connector : RSocketConnector.create()); + + final LoadbalanceStrategy strategy = + (this.loadbalanceStrategy != null + ? this.loadbalanceStrategy + : new RoundRobinLoadbalanceStrategy()); + + if (strategy instanceof ClientLoadbalanceStrategy) { + ((ClientLoadbalanceStrategy) strategy).initialize(connector); + } + + return new LoadbalanceRSocketClient( + new RSocketPool(connector, this.targetPublisher, strategy)); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceStrategy.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceStrategy.java new file mode 100644 index 000000000..5662448e7 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceStrategy.java @@ -0,0 +1,38 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.RSocket; +import java.util.List; + +/** + * Strategy to select an {@link RSocket} given a list of instances for load-balancing purposes. A + * simple implementation might go in round-robin fashion while a more sophisticated strategy might + * check availability, track usage stats, and so on. + * + * @since 1.1 + */ +@FunctionalInterface +public interface LoadbalanceStrategy { + + /** + * Select an {@link RSocket} from the given non-empty list. + * + * @param sockets the list to choose from + * @return the selected instance + */ + RSocket select(List sockets); +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceTarget.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceTarget.java new file mode 100644 index 000000000..3b5d71e4e --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceTarget.java @@ -0,0 +1,79 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.ClientTransport; +import org.reactivestreams.Publisher; + +/** + * Representation for a load-balance target used as input to {@link LoadbalanceRSocketClient} that + * in turn maintains and peridodically updates a list of current load-balance targets. The {@link + * #getKey()} is used to identify a target uniquely while the {@link #getTransport() transport} is + * used to connect to the target server. + * + * @since 1.1 + * @see LoadbalanceRSocketClient#create(RSocketConnector, Publisher) + */ +public class LoadbalanceTarget { + + final String key; + final ClientTransport transport; + + private LoadbalanceTarget(String key, ClientTransport transport) { + this.key = key; + this.transport = transport; + } + + /** Return the key that identifies this target uniquely. */ + public String getKey() { + return key; + } + + /** Return the transport to use to connect to the target server. */ + public ClientTransport getTransport() { + return transport; + } + + /** + * Create a new {@link LoadbalanceTarget} with the given key and {@link ClientTransport}. The key + * can be anything that identifies the target uniquely, e.g. SocketAddress, URL, and so on. + * + * @param key identifies the load-balance target uniquely + * @param transport for connecting to the target + * @return the created instance + */ + public static LoadbalanceTarget from(String key, ClientTransport transport) { + return new LoadbalanceTarget(key, transport); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + LoadbalanceTarget that = (LoadbalanceTarget) other; + return key.equals(that.key); + } + + @Override + public int hashCode() { + return key.hashCode(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/Median.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/Median.java new file mode 100644 index 000000000..5319706f9 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/Median.java @@ -0,0 +1,99 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.loadbalance; + +/** This implementation gives better results because it considers more data-point. */ +class Median extends FrugalQuantile { + + public Median() { + super(0.5, 1.0); + } + + public synchronized void reset() { + super.reset(0.5); + } + + @Override + public synchronized void insert(double x) { + if (sign == 0) { + estimate = x; + sign = 1; + } else { + final double estimate = this.estimate; + if (x > estimate) { + greaterThanZero(x); + } else if (x < estimate) { + lessThanZero(x); + } + } + } + + private void greaterThanZero(double x) { + double estimate = this.estimate; + + step += sign; + + if (step > 0) { + estimate += step; + } else { + estimate += 1; + } + + if (estimate > x) { + step += (x - estimate); + estimate = x; + } + + if (sign < 0) { + step = 1; + } + + sign = 1; + + this.estimate = estimate; + } + + private void lessThanZero(double x) { + double estimate = this.estimate; + + step -= sign; + + if (step > 0) { + estimate -= step; + } else { + estimate--; + } + + if (estimate < x) { + step += (estimate - x); + estimate = x; + } + + if (sign > 0) { + step = 1; + } + + sign = -1; + + this.estimate = estimate; + } + + @Override + public String toString() { + return "Median(v=" + estimate + ")"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/MonoDeferredResolution.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/MonoDeferredResolution.java new file mode 100644 index 000000000..69838f1b6 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/MonoDeferredResolution.java @@ -0,0 +1,226 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.function.BiConsumer; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +abstract class MonoDeferredResolution extends Mono + implements CoreSubscriber, Subscription, Scannable, BiConsumer { + + final ResolvingOperator parent; + final Payload payload; + final FrameType requestType; + + volatile long requested; + + @SuppressWarnings("rawtypes") + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(MonoDeferredResolution.class, "requested"); + + static final long STATE_UNSUBSCRIBED = -1; + static final long STATE_SUBSCRIBER_SET = 0; + static final long STATE_SUBSCRIBED = -2; + static final long STATE_TERMINATED = Long.MIN_VALUE; + + Subscription s; + CoreSubscriber actual; + boolean done; + + MonoDeferredResolution(ResolvingOperator parent, Payload payload, FrameType requestType) { + this.parent = parent; + this.payload = payload; + this.requestType = requestType; + + REQUESTED.lazySet(this, STATE_UNSUBSCRIBED); + } + + @Override + public final void subscribe(CoreSubscriber actual) { + if (this.requested == STATE_UNSUBSCRIBED + && REQUESTED.compareAndSet(this, STATE_UNSUBSCRIBED, STATE_SUBSCRIBER_SET)) { + + actual.onSubscribe(this); + + if (this.requested == STATE_TERMINATED) { + return; + } + + this.actual = actual; + this.parent.observe(this); + } else { + Operators.error(actual, new IllegalStateException("Only a single Subscriber allowed")); + } + } + + @Override + public final Context currentContext() { + return this.actual.currentContext(); + } + + @Nullable + @Override + public Object scanUnsafe(Attr key) { + long state = this.requested; + + if (key == Attr.PARENT) { + return this.s; + } + if (key == Attr.ACTUAL) { + return this.parent; + } + if (key == Attr.TERMINATED) { + return this.done; + } + if (key == Attr.CANCELLED) { + return state == STATE_TERMINATED; + } + + return null; + } + + @Override + public final void onSubscribe(Subscription s) { + final long state = this.requested; + Subscription a = this.s; + if (state == STATE_TERMINATED) { + s.cancel(); + return; + } + if (a != null) { + s.cancel(); + return; + } + + long r; + long accumulated = 0; + for (; ; ) { + r = this.requested; + + if (r == STATE_TERMINATED || r == STATE_SUBSCRIBED) { + s.cancel(); + return; + } + + this.s = s; + + long toRequest = r - accumulated; + if (toRequest > 0) { // if there is something, + s.request(toRequest); // then we do a request on the given subscription + } + accumulated = r; + + if (REQUESTED.compareAndSet(this, r, STATE_SUBSCRIBED)) { + return; + } + } + } + + @Override + public final void onNext(RESULT payload) { + this.actual.onNext(payload); + } + + @Override + public final void onError(Throwable t) { + if (this.done) { + Operators.onErrorDropped(t, this.actual.currentContext()); + return; + } + + this.done = true; + this.actual.onError(t); + } + + @Override + public final void onComplete() { + if (this.done) { + return; + } + + this.done = true; + this.actual.onComplete(); + } + + @Override + public final void request(long n) { + if (Operators.validate(n)) { + long r = this.requested; // volatile read beforehand + + if (r > STATE_SUBSCRIBED) { // works only in case onSubscribe has not happened + long u; + for (; ; ) { // normal CAS loop with overflow protection + if (r == Long.MAX_VALUE) { + // if r == Long.MAX_VALUE then we dont care and we can loose this + // request just in case of racing + return; + } + u = Operators.addCap(r, n); + if (REQUESTED.compareAndSet(this, r, u)) { + // Means increment happened before onSubscribe + return; + } else { + // Means increment happened after onSubscribe + + // update new state to see what exactly happened (onSubscribe |cancel | requestN) + r = this.requested; + + // check state (expect -1 | -2 to exit, otherwise repeat) + if (r < 0) { + break; + } + } + } + } + + if (r == STATE_TERMINATED) { // if canceled, just exit + return; + } + + // if onSubscribe -> subscription exists (and we sure of that because volatile read + // after volatile write) so we can execute requestN on the subscription + this.s.request(n); + } + } + + public final void cancel() { + long state = REQUESTED.getAndSet(this, STATE_TERMINATED); + if (state == STATE_TERMINATED) { + return; + } + + if (state == STATE_SUBSCRIBED) { + this.s.cancel(); + } else { + this.parent.remove(this); + ReferenceCountUtil.safeRelease(this.payload); + } + } + + boolean isTerminated() { + return this.requested == STATE_TERMINATED; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/PooledRSocket.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/PooledRSocket.java new file mode 100644 index 000000000..a77329d31 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/PooledRSocket.java @@ -0,0 +1,310 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.util.context.Context; + +/** Default implementation of {@link RSocket} stored in {@link RSocketPool} */ +final class PooledRSocket extends ResolvingOperator + implements CoreSubscriber, RSocket { + + final RSocketPool parent; + final Mono rSocketSource; + final LoadbalanceTarget loadbalanceTarget; + final Sinks.Empty onCloseSink; + + volatile Subscription s; + + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(PooledRSocket.class, Subscription.class, "s"); + + PooledRSocket( + RSocketPool parent, Mono rSocketSource, LoadbalanceTarget loadbalanceTarget) { + this.parent = parent; + this.rSocketSource = rSocketSource; + this.loadbalanceTarget = loadbalanceTarget; + this.onCloseSink = Sinks.unsafe().empty(); + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } + } + + @Override + public void onComplete() { + final Subscription s = this.s; + final RSocket value = this.value; + + if (s == Operators.cancelledSubscription() || !S.compareAndSet(this, s, null)) { + this.doFinally(); + return; + } + + if (value == null) { + this.terminate(new IllegalStateException("Source completed empty")); + } else { + this.complete(value); + } + } + + @Override + public void onError(Throwable t) { + final Subscription s = this.s; + + if (s == Operators.cancelledSubscription() + || S.getAndSet(this, Operators.cancelledSubscription()) + == Operators.cancelledSubscription()) { + this.doFinally(); + Operators.onErrorDropped(t, Context.empty()); + return; + } + + this.doFinally(); + // terminate upstream (retryBackoff has exhausted) and remove from the parent target list + this.doCleanup(t); + } + + @Override + public void onNext(RSocket value) { + if (this.s == Operators.cancelledSubscription()) { + this.doOnValueExpired(value); + return; + } + + this.value = value; + // volatile write and check on racing + this.doFinally(); + } + + @Override + protected void doSubscribe() { + this.rSocketSource.subscribe(this); + } + + @Override + protected void doOnValueResolved(RSocket value) { + value.onClose().subscribe(null, this::doCleanup, () -> doCleanup(ON_DISPOSE)); + } + + void doCleanup(Throwable t) { + if (isDisposed()) { + return; + } + + this.terminate(t); + + final RSocketPool parent = this.parent; + for (; ; ) { + final PooledRSocket[] sockets = parent.activeSockets; + final int activeSocketsCount = sockets.length; + + int index = -1; + for (int i = 0; i < activeSocketsCount; i++) { + if (sockets[i] == this) { + index = i; + break; + } + } + + if (index == -1) { + break; + } + + final PooledRSocket[] newSockets; + if (activeSocketsCount == 1) { + newSockets = RSocketPool.EMPTY; + } else { + final int lastIndex = activeSocketsCount - 1; + + newSockets = new PooledRSocket[lastIndex]; + if (index != 0) { + System.arraycopy(sockets, 0, newSockets, 0, index); + } + + if (index != lastIndex) { + System.arraycopy(sockets, index + 1, newSockets, index, lastIndex - index); + } + } + + if (RSocketPool.ACTIVE_SOCKETS.compareAndSet(parent, sockets, newSockets)) { + break; + } + } + + if (t == ON_DISPOSE) { + this.onCloseSink.tryEmitEmpty(); + } else { + this.onCloseSink.tryEmitError(t); + } + } + + @Override + protected void doOnValueExpired(RSocket value) { + value.dispose(); + } + + @Override + protected void doOnDispose() { + Operators.terminate(S, this); + + final RSocket value = this.value; + if (value != null) { + value.onClose().subscribe(null, onCloseSink::tryEmitError, onCloseSink::tryEmitEmpty); + } else { + onCloseSink.tryEmitEmpty(); + } + } + + @Override + public Mono fireAndForget(Payload payload) { + return new MonoInner<>(this, payload, FrameType.REQUEST_FNF); + } + + @Override + public Mono requestResponse(Payload payload) { + return new MonoInner<>(this, payload, FrameType.REQUEST_RESPONSE); + } + + @Override + public Flux requestStream(Payload payload) { + return new FluxInner<>(this, payload, FrameType.REQUEST_STREAM); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return new FluxInner<>(this, payloads, FrameType.REQUEST_CHANNEL); + } + + @Override + public Mono metadataPush(Payload payload) { + return new MonoInner<>(this, payload, FrameType.METADATA_PUSH); + } + + LoadbalanceTarget target() { + return this.loadbalanceTarget; + } + + @Override + public Mono onClose() { + return this.onCloseSink.asMono(); + } + + @Override + public double availability() { + final RSocket socket = valueIfResolved(); + return socket != null ? socket.availability() : 0.0d; + } + + static final class MonoInner extends MonoDeferredResolution { + + MonoInner(PooledRSocket parent, Payload payload, FrameType requestType) { + super(parent, payload, requestType); + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public void accept(RSocket rSocket, Throwable t) { + if (isTerminated()) { + return; + } + + if (t != null) { + ReferenceCountUtil.safeRelease(this.payload); + onError(t); + return; + } + + if (rSocket != null) { + Mono source; + switch (this.requestType) { + case REQUEST_FNF: + source = rSocket.fireAndForget(this.payload); + break; + case REQUEST_RESPONSE: + source = rSocket.requestResponse(this.payload); + break; + case METADATA_PUSH: + source = rSocket.metadataPush(this.payload); + break; + default: + Operators.error(this.actual, new IllegalStateException("Should never happen")); + return; + } + + source.subscribe((CoreSubscriber) this); + } else { + parent.observe(this); + } + } + } + + static final class FluxInner extends FluxDeferredResolution { + + FluxInner(PooledRSocket parent, INPUT fluxOrPayload, FrameType requestType) { + super(parent, fluxOrPayload, requestType); + } + + @Override + @SuppressWarnings("unchecked") + public void accept(RSocket rSocket, Throwable t) { + if (isTerminated()) { + return; + } + + if (t != null) { + if (this.requestType == FrameType.REQUEST_STREAM) { + ReferenceCountUtil.safeRelease(this.fluxOrPayload); + } + onError(t); + return; + } + + if (rSocket != null) { + Flux source; + switch (this.requestType) { + case REQUEST_STREAM: + source = rSocket.requestStream((Payload) this.fluxOrPayload); + break; + case REQUEST_CHANNEL: + source = rSocket.requestChannel((Flux) this.fluxOrPayload); + break; + default: + Operators.error(this.actual, new IllegalStateException("Should never happen")); + return; + } + + source.subscribe(this); + } else { + parent.observe(this); + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/Quantile.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/Quantile.java new file mode 100644 index 000000000..84c699197 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/Quantile.java @@ -0,0 +1,28 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +interface Quantile { + /** @return the estimation of the current value of the quantile */ + double estimation(); + + /** + * Insert a data point `x` in the quantile estimator. + * + * @param x the data point to add. + */ + void insert(double x); +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/RSocketPool.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/RSocketPool.java new file mode 100644 index 000000000..59d9678d0 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/RSocketPool.java @@ -0,0 +1,532 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.frame.FrameType; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.ListIterator; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.stream.Collectors; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; + +class RSocketPool extends ResolvingOperator + implements CoreSubscriber>, Closeable { + + static final AtomicReferenceFieldUpdater ACTIVE_SOCKETS = + AtomicReferenceFieldUpdater.newUpdater( + RSocketPool.class, PooledRSocket[].class, "activeSockets"); + static final PooledRSocket[] EMPTY = new PooledRSocket[0]; + static final PooledRSocket[] TERMINATED = new PooledRSocket[0]; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(RSocketPool.class, Subscription.class, "s"); + final DeferredResolutionRSocket deferredResolutionRSocket = new DeferredResolutionRSocket(this); + final RSocketConnector connector; + final LoadbalanceStrategy loadbalanceStrategy; + final Sinks.Empty onAllClosedSink = Sinks.unsafe().empty(); + volatile PooledRSocket[] activeSockets; + volatile Subscription s; + + public RSocketPool( + RSocketConnector connector, + Publisher> targetPublisher, + LoadbalanceStrategy loadbalanceStrategy) { + this.connector = connector; + this.loadbalanceStrategy = loadbalanceStrategy; + + ACTIVE_SOCKETS.lazySet(this, EMPTY); + + targetPublisher.subscribe(this); + } + + @Override + public Mono onClose() { + return onAllClosedSink.asMono(); + } + + @Override + protected void doOnDispose() { + Operators.terminate(S, this); + + RSocket[] activeSockets = ACTIVE_SOCKETS.getAndSet(this, TERMINATED); + for (RSocket rSocket : activeSockets) { + rSocket.dispose(); + } + + if (activeSockets.length > 0) { + Mono.whenDelayError( + Arrays.stream(activeSockets).map(RSocket::onClose).collect(Collectors.toList())) + .subscribe(null, onAllClosedSink::tryEmitError, onAllClosedSink::tryEmitEmpty); + } else { + onAllClosedSink.tryEmitEmpty(); + } + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } + } + + @Override + public void onNext(List targets) { + if (isDisposed()) { + return; + } + + // This operation should happen less frequently than calls to select() (which are per request) + // and therefore it is acceptable somewhat less efficient. + + PooledRSocket[] previouslyActiveSockets; + PooledRSocket[] inactiveSockets; + PooledRSocket[] socketsToUse; + for (; ; ) { + HashMap rSocketSuppliersCopy = new HashMap<>(targets.size()); + + int j = 0; + for (LoadbalanceTarget target : targets) { + rSocketSuppliersCopy.put(target, j++); + } + + // Intersect current and new list of targets and find the ones to keep vs dispose + previouslyActiveSockets = this.activeSockets; + inactiveSockets = new PooledRSocket[previouslyActiveSockets.length]; + PooledRSocket[] nextActiveSockets = + new PooledRSocket[previouslyActiveSockets.length + rSocketSuppliersCopy.size()]; + int activeSocketsPosition = 0; + int inactiveSocketsPosition = 0; + for (int i = 0; i < previouslyActiveSockets.length; i++) { + PooledRSocket rSocket = previouslyActiveSockets[i]; + + Integer index = rSocketSuppliersCopy.remove(rSocket.target()); + if (index == null) { + // if one of the active rSockets is not included, we remove it and put in the + // pending removal + if (!rSocket.isDisposed()) { + inactiveSockets[inactiveSocketsPosition++] = rSocket; + // TODO: provide a meaningful algo for keeping removed rsocket in the list + // nextActiveSockets[position++] = rSocket; + } + } else { + if (!rSocket.isDisposed()) { + // keep old RSocket instance + nextActiveSockets[activeSocketsPosition++] = rSocket; + } else { + // put newly create RSocket instance + LoadbalanceTarget target = targets.get(index); + nextActiveSockets[activeSocketsPosition++] = + new PooledRSocket(this, this.connector.connect(target.getTransport()), target); + } + } + } + + // The remainder are the brand new targets + for (LoadbalanceTarget target : rSocketSuppliersCopy.keySet()) { + nextActiveSockets[activeSocketsPosition++] = + new PooledRSocket(this, this.connector.connect(target.getTransport()), target); + } + + if (activeSocketsPosition == 0) { + socketsToUse = EMPTY; + } else { + socketsToUse = Arrays.copyOf(nextActiveSockets, activeSocketsPosition); + } + if (ACTIVE_SOCKETS.compareAndSet(this, previouslyActiveSockets, socketsToUse)) { + break; + } + } + + for (PooledRSocket inactiveSocket : inactiveSockets) { + if (inactiveSocket == null) { + break; + } + + inactiveSocket.dispose(); + } + + if (isPending()) { + // notifies that upstream is resolved + if (socketsToUse != EMPTY) { + //noinspection ConstantConditions + complete(this); + } + } + } + + @Override + public void onError(Throwable t) { + // indicates upstream termination + S.set(this, Operators.cancelledSubscription()); + // propagates error and terminates the whole pool + terminate(t); + } + + @Override + public void onComplete() { + // indicates upstream termination + S.set(this, Operators.cancelledSubscription()); + } + + RSocket select() { + if (isDisposed()) { + return this.deferredResolutionRSocket; + } + + RSocket selected = doSelect(); + + if (selected == null) { + if (this.s == Operators.cancelledSubscription()) { + terminate(new CancellationException("Pool is exhausted")); + } else { + invalidate(); + + // check since it is possible that between doSelect() and invalidate() we might + // have received new sockets + selected = doSelect(); + if (selected != null) { + return selected; + } + } + return this.deferredResolutionRSocket; + } + + return selected; + } + + @Nullable + RSocket doSelect() { + PooledRSocket[] sockets = this.activeSockets; + + if (sockets == EMPTY || sockets == TERMINATED) { + return null; + } + + return this.loadbalanceStrategy.select(WrappingList.wrap(sockets)); + } + + static class DeferredResolutionRSocket implements RSocket { + + final RSocketPool parent; + + DeferredResolutionRSocket(RSocketPool parent) { + this.parent = parent; + } + + @Override + public Mono fireAndForget(Payload payload) { + return new MonoInner<>(this.parent, payload, FrameType.REQUEST_FNF); + } + + @Override + public Mono requestResponse(Payload payload) { + return new MonoInner<>(this.parent, payload, FrameType.REQUEST_RESPONSE); + } + + @Override + public Flux requestStream(Payload payload) { + return new FluxInner<>(this.parent, payload, FrameType.REQUEST_STREAM); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return new FluxInner<>(this.parent, payloads, FrameType.REQUEST_CHANNEL); + } + + @Override + public Mono metadataPush(Payload payload) { + return new MonoInner<>(this.parent, payload, FrameType.METADATA_PUSH); + } + } + + static final class MonoInner extends MonoDeferredResolution { + + MonoInner(RSocketPool parent, Payload payload, FrameType requestType) { + super(parent, payload, requestType); + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public void accept(Object aVoid, Throwable t) { + if (isTerminated()) { + return; + } + + if (t != null) { + ReferenceCountUtil.safeRelease(this.payload); + onError(t); + return; + } + + RSocketPool parent = (RSocketPool) this.parent; + for (; ; ) { + RSocket rSocket = parent.doSelect(); + if (rSocket != null) { + Mono source; + switch (this.requestType) { + case REQUEST_FNF: + source = rSocket.fireAndForget(this.payload); + break; + case REQUEST_RESPONSE: + source = rSocket.requestResponse(this.payload); + break; + case METADATA_PUSH: + source = rSocket.metadataPush(this.payload); + break; + default: + Operators.error(this.actual, new IllegalStateException("Should never happen")); + return; + } + + source.subscribe((CoreSubscriber) this); + + return; + } + + final int state = parent.add(this); + + if (state == ADDED_STATE) { + return; + } + + if (state == TERMINATED_STATE) { + final Throwable error = parent.t; + ReferenceCountUtil.safeRelease(this.payload); + onError(error); + return; + } + } + } + } + + static final class FluxInner extends FluxDeferredResolution { + + FluxInner(RSocketPool parent, INPUT fluxOrPayload, FrameType requestType) { + super(parent, fluxOrPayload, requestType); + } + + @Override + @SuppressWarnings("unchecked") + public void accept(Object aVoid, Throwable t) { + if (isTerminated()) { + return; + } + + if (t != null) { + if (this.requestType == FrameType.REQUEST_STREAM) { + ReferenceCountUtil.safeRelease(this.fluxOrPayload); + } + onError(t); + return; + } + + RSocketPool parent = (RSocketPool) this.parent; + for (; ; ) { + RSocket rSocket = parent.doSelect(); + if (rSocket != null) { + Flux source; + switch (this.requestType) { + case REQUEST_STREAM: + source = rSocket.requestStream((Payload) this.fluxOrPayload); + break; + case REQUEST_CHANNEL: + source = rSocket.requestChannel((Flux) this.fluxOrPayload); + break; + default: + Operators.error(this.actual, new IllegalStateException("Should never happen")); + return; + } + + source.subscribe(this); + + return; + } + + final int state = parent.add(this); + + if (state == ADDED_STATE) { + return; + } + + if (state == TERMINATED_STATE) { + final Throwable error = parent.t; + if (this.requestType == FrameType.REQUEST_STREAM) { + ReferenceCountUtil.safeRelease(this.fluxOrPayload); + } + onError(error); + return; + } + } + } + } + + static final class WrappingList implements List { + + static final ThreadLocal INSTANCE = ThreadLocal.withInitial(WrappingList::new); + + private PooledRSocket[] activeSockets; + + static List wrap(PooledRSocket[] activeSockets) { + final WrappingList sockets = INSTANCE.get(); + sockets.activeSockets = activeSockets; + return sockets; + } + + @Override + public RSocket get(int index) { + final PooledRSocket socket = activeSockets[index]; + + RSocket realValue = socket.value; + if (realValue != null) { + return realValue; + } + + realValue = socket.valueIfResolved(); + if (realValue != null) { + return realValue; + } + + return socket; + } + + @Override + public int size() { + return activeSockets.length; + } + + @Override + public boolean isEmpty() { + return activeSockets.length == 0; + } + + @Override + public Object[] toArray() { + return activeSockets; + } + + @Override + @SuppressWarnings("unchecked") + public T[] toArray(T[] a) { + return (T[]) activeSockets; + } + + @Override + public boolean contains(Object o) { + throw new UnsupportedOperationException(); + } + + @Override + public Iterator iterator() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean add(RSocket weightedRSocket) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean remove(Object o) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean containsAll(Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean addAll(Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean addAll(int index, Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean removeAll(Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean retainAll(Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(); + } + + @Override + public RSocket set(int index, RSocket element) { + throw new UnsupportedOperationException(); + } + + @Override + public void add(int index, RSocket element) { + throw new UnsupportedOperationException(); + } + + @Override + public RSocket remove(int index) { + throw new UnsupportedOperationException(); + } + + @Override + public int indexOf(Object o) { + throw new UnsupportedOperationException(); + } + + @Override + public int lastIndexOf(Object o) { + throw new UnsupportedOperationException(); + } + + @Override + public ListIterator listIterator() { + throw new UnsupportedOperationException(); + } + + @Override + public ListIterator listIterator(int index) { + throw new UnsupportedOperationException(); + } + + @Override + public List subList(int fromIndex, int toIndex) { + throw new UnsupportedOperationException(); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/ResolvingOperator.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/ResolvingOperator.java new file mode 100644 index 000000000..52f16e166 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/ResolvingOperator.java @@ -0,0 +1,420 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import java.time.Duration; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.BiConsumer; +import reactor.core.Disposable; +import reactor.core.Exceptions; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +// This class is a copy of the same class in io.rsocket.core + +class ResolvingOperator implements Disposable { + + static final CancellationException ON_DISPOSE = new CancellationException("Disposed"); + + volatile int wip; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(ResolvingOperator.class, "wip"); + + volatile BiConsumer[] subscribers; + + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater SUBSCRIBERS = + AtomicReferenceFieldUpdater.newUpdater( + ResolvingOperator.class, BiConsumer[].class, "subscribers"); + + @SuppressWarnings("unchecked") + static final BiConsumer[] EMPTY_UNSUBSCRIBED = new BiConsumer[0]; + + @SuppressWarnings("unchecked") + static final BiConsumer[] EMPTY_SUBSCRIBED = new BiConsumer[0]; + + @SuppressWarnings("unchecked") + static final BiConsumer[] READY = new BiConsumer[0]; + + @SuppressWarnings("unchecked") + static final BiConsumer[] TERMINATED = new BiConsumer[0]; + + static final int ADDED_STATE = 0; + static final int READY_STATE = 1; + static final int TERMINATED_STATE = 2; + + T value; + Throwable t; + + public ResolvingOperator() { + + SUBSCRIBERS.lazySet(this, EMPTY_UNSUBSCRIBED); + } + + @Override + public final void dispose() { + this.terminate(ON_DISPOSE); + } + + @Override + public final boolean isDisposed() { + return this.subscribers == TERMINATED; + } + + public final boolean isPending() { + BiConsumer[] state = this.subscribers; + return state != READY && state != TERMINATED; + } + + @Nullable + public final T valueIfResolved() { + if (this.subscribers == READY) { + T value = this.value; + if (value != null) { + return value; + } + } + + return null; + } + + final void observe(BiConsumer actual) { + for (; ; ) { + final int state = this.add(actual); + + T value = this.value; + + if (state == READY_STATE) { + if (value != null) { + actual.accept(value, null); + return; + } + // value == null means racing between invalidate and this subscriber + // thus, we have to loop again + continue; + } else if (state == TERMINATED_STATE) { + actual.accept(null, this.t); + return; + } + + return; + } + } + + /** + * Block the calling thread for the specified time, waiting for the completion of this {@code + * ReconnectMono}. If the {@link ResolvingOperator} is completed with an error a RuntimeException + * that wraps the error is thrown. + * + * @param timeout the timeout value as a {@link Duration} + * @return the value of this {@link ResolvingOperator} or {@code null} if the timeout is reached + * and the {@link ResolvingOperator} has not completed + * @throws RuntimeException if terminated with error + * @throws IllegalStateException if timed out or {@link Thread} was interrupted with {@link + * InterruptedException} + */ + @Nullable + @SuppressWarnings({"uncheked", "BusyWait"}) + public T block(@Nullable Duration timeout) { + try { + BiConsumer[] subscribers = this.subscribers; + if (subscribers == READY) { + final T value = this.value; + if (value != null) { + return value; + } else { + // value == null means racing between invalidate and this block + // thus, we have to update the state again and see what happened + subscribers = this.subscribers; + } + } + + if (subscribers == TERMINATED) { + RuntimeException re = Exceptions.propagate(this.t); + re = Exceptions.addSuppressed(re, new Exception("Terminated with an error")); + throw re; + } + + // connect once + if (subscribers == EMPTY_UNSUBSCRIBED + && SUBSCRIBERS.compareAndSet(this, EMPTY_UNSUBSCRIBED, EMPTY_SUBSCRIBED)) { + this.doSubscribe(); + } + + long delay; + if (null == timeout) { + delay = 0L; + } else { + delay = System.nanoTime() + timeout.toNanos(); + } + for (; ; ) { + subscribers = this.subscribers; + + if (subscribers == READY) { + final T value = this.value; + if (value != null) { + return value; + } else { + // value == null means racing between invalidate and this block + // thus, we have to update the state again and see what happened + subscribers = this.subscribers; + } + } + if (subscribers == TERMINATED) { + RuntimeException re = Exceptions.propagate(this.t); + re = Exceptions.addSuppressed(re, new Exception("Terminated with an error")); + throw re; + } + if (timeout != null && delay < System.nanoTime()) { + throw new IllegalStateException("Timeout on Mono blocking read"); + } + + // connect again since invalidate() has happened in between + if (subscribers == EMPTY_UNSUBSCRIBED + && SUBSCRIBERS.compareAndSet(this, EMPTY_UNSUBSCRIBED, EMPTY_SUBSCRIBED)) { + this.doSubscribe(); + } + + Thread.sleep(1); + } + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + + throw new IllegalStateException("Thread Interruption on Mono blocking read"); + } + } + + @SuppressWarnings("unchecked") + final void terminate(Throwable t) { + if (isDisposed()) { + Operators.onErrorDropped(t, Context.empty()); + return; + } + + // writes happens before volatile write + this.t = t; + + final BiConsumer[] subscribers = SUBSCRIBERS.getAndSet(this, TERMINATED); + if (subscribers == TERMINATED) { + Operators.onErrorDropped(t, Context.empty()); + return; + } + + this.doOnDispose(); + + this.doFinally(); + + for (BiConsumer consumer : subscribers) { + consumer.accept(null, t); + } + } + + final void complete(T value) { + BiConsumer[] subscribers = this.subscribers; + if (subscribers == TERMINATED) { + this.doOnValueExpired(value); + return; + } + + this.value = value; + + for (; ; ) { + // ensures TERMINATE is going to be replaced with READY + if (SUBSCRIBERS.compareAndSet(this, subscribers, READY)) { + break; + } + + subscribers = this.subscribers; + + if (subscribers == TERMINATED) { + this.doFinally(); + return; + } + } + + this.doOnValueResolved(value); + + for (BiConsumer consumer : subscribers) { + consumer.accept(value, null); + } + } + + protected void doOnValueResolved(T value) { + // no ops + } + + final void doFinally() { + if (WIP.getAndIncrement(this) != 0) { + return; + } + + int m = 1; + T value; + + for (; ; ) { + value = this.value; + if (value != null && isDisposed()) { + this.value = null; + this.doOnValueExpired(value); + return; + } + + m = WIP.addAndGet(this, -m); + if (m == 0) { + return; + } + } + } + + final void invalidate() { + if (this.subscribers == TERMINATED) { + return; + } + + final BiConsumer[] subscribers = this.subscribers; + + if (subscribers == READY) { + // guarded section to ensure we expire value exactly once if there is racing + if (WIP.getAndIncrement(this) != 0) { + return; + } + + final T value = this.value; + if (value != null) { + this.value = null; + this.doOnValueExpired(value); + } + + int m = 1; + for (; ; ) { + if (isDisposed()) { + return; + } + + m = WIP.addAndGet(this, -m); + if (m == 0) { + break; + } + } + + SUBSCRIBERS.compareAndSet(this, READY, EMPTY_UNSUBSCRIBED); + } + } + + protected void doOnValueExpired(T value) { + // no ops + } + + protected void doOnDispose() { + // no ops + } + + public final boolean connect() { + for (; ; ) { + final BiConsumer[] a = this.subscribers; + + if (a == TERMINATED) { + return false; + } + + if (a == READY) { + return true; + } + + if (a != EMPTY_UNSUBSCRIBED) { + // do nothing if already started + return true; + } + + if (SUBSCRIBERS.compareAndSet(this, a, EMPTY_SUBSCRIBED)) { + this.doSubscribe(); + return true; + } + } + } + + final int add(BiConsumer ps) { + for (; ; ) { + BiConsumer[] a = this.subscribers; + + if (a == TERMINATED) { + return TERMINATED_STATE; + } + + if (a == READY) { + return READY_STATE; + } + + int n = a.length; + @SuppressWarnings("unchecked") + BiConsumer[] b = new BiConsumer[n + 1]; + System.arraycopy(a, 0, b, 0, n); + b[n] = ps; + + if (SUBSCRIBERS.compareAndSet(this, a, b)) { + if (a == EMPTY_UNSUBSCRIBED) { + this.doSubscribe(); + } + return ADDED_STATE; + } + } + } + + protected void doSubscribe() { + // no ops + } + + @SuppressWarnings("unchecked") + final void remove(BiConsumer ps) { + for (; ; ) { + BiConsumer[] a = this.subscribers; + int n = a.length; + if (n == 0) { + return; + } + + int j = -1; + for (int i = 0; i < n; i++) { + if (a[i] == ps) { + j = i; + break; + } + } + + if (j < 0) { + return; + } + + BiConsumer[] b; + + if (n == 1) { + b = EMPTY_SUBSCRIBED; + } else { + b = new BiConsumer[n - 1]; + System.arraycopy(a, 0, b, 0, j); + System.arraycopy(a, j + 1, b, j, n - j - 1); + } + if (SUBSCRIBERS.compareAndSet(this, a, b)) { + return; + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategy.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategy.java new file mode 100644 index 000000000..f1a9f8c55 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategy.java @@ -0,0 +1,42 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.RSocket; +import java.util.List; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; + +/** + * Simple {@link LoadbalanceStrategy} that selects the {@code RSocket} to use in round-robin order. + * + * @since 1.1 + */ +public class RoundRobinLoadbalanceStrategy implements LoadbalanceStrategy { + + volatile int nextIndex; + + private static final AtomicIntegerFieldUpdater NEXT_INDEX = + AtomicIntegerFieldUpdater.newUpdater(RoundRobinLoadbalanceStrategy.class, "nextIndex"); + + @Override + public RSocket select(List sockets) { + int length = sockets.size(); + + int indexToUse = Math.abs(NEXT_INDEX.getAndIncrement(this) % length); + + return sockets.get(indexToUse); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategy.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategy.java new file mode 100644 index 000000000..c30c8ad6b --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategy.java @@ -0,0 +1,249 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.loadbalance; + +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.plugins.RequestInterceptor; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.Function; +import reactor.util.annotation.Nullable; + +/** + * {@link LoadbalanceStrategy} that assigns a weight to each {@code RSocket} based on {@link + * RSocket#availability() availability} and usage statistics. The weight is used to decide which + * {@code RSocket} to select. + * + *

Use {@link #create()} or a {@link #builder() Builder} to create an instance. + * + * @since 1.1 + * @see Predictive Load-Balancing: Unfair but + * Faster & more Robust + * @see WeightedStatsRequestInterceptor + */ +public class WeightedLoadbalanceStrategy implements ClientLoadbalanceStrategy { + + private static final double EXP_FACTOR = 4.0; + + final int maxPairSelectionAttempts; + final Function weightedStatsResolver; + + private WeightedLoadbalanceStrategy( + int numberOfAttempts, @Nullable Function resolver) { + this.maxPairSelectionAttempts = numberOfAttempts; + this.weightedStatsResolver = (resolver != null ? resolver : new DefaultWeightedStatsResolver()); + } + + @Override + public void initialize(RSocketConnector connector) { + final Function resolver = weightedStatsResolver; + if (resolver instanceof DefaultWeightedStatsResolver) { + ((DefaultWeightedStatsResolver) resolver).init(connector); + } + } + + @Override + public RSocket select(List sockets) { + final int size = sockets.size(); + + RSocket weightedRSocket; + final Function weightedStatsResolver = this.weightedStatsResolver; + switch (size) { + case 1: + weightedRSocket = sockets.get(0); + break; + case 2: + { + RSocket rsc1 = sockets.get(0); + RSocket rsc2 = sockets.get(1); + + double w1 = algorithmicWeight(rsc1, weightedStatsResolver.apply(rsc1)); + double w2 = algorithmicWeight(rsc2, weightedStatsResolver.apply(rsc2)); + if (w1 < w2) { + weightedRSocket = rsc2; + } else { + weightedRSocket = rsc1; + } + } + break; + default: + { + RSocket rsc1 = null; + RSocket rsc2 = null; + + for (int i = 0; i < this.maxPairSelectionAttempts; i++) { + int i1 = ThreadLocalRandom.current().nextInt(size); + int i2 = ThreadLocalRandom.current().nextInt(size - 1); + + if (i2 >= i1) { + i2++; + } + rsc1 = sockets.get(i1); + rsc2 = sockets.get(i2); + if (rsc1.availability() > 0.0 && rsc2.availability() > 0.0) { + break; + } + } + + if (rsc1 != null & rsc2 != null) { + double w1 = algorithmicWeight(rsc1, weightedStatsResolver.apply(rsc1)); + double w2 = algorithmicWeight(rsc2, weightedStatsResolver.apply(rsc2)); + + if (w1 < w2) { + weightedRSocket = rsc2; + } else { + weightedRSocket = rsc1; + } + } else if (rsc1 != null) { + weightedRSocket = rsc1; + } else { + weightedRSocket = rsc2; + } + } + } + + return weightedRSocket; + } + + private static double algorithmicWeight( + RSocket rSocket, @Nullable final WeightedStats weightedStats) { + if (weightedStats == null) { + return 1.0; + } + if (rSocket.isDisposed() || rSocket.availability() == 0.0) { + return 0.0; + } + final int pending = weightedStats.pending(); + + double latency = weightedStats.predictedLatency(); + + final double low = weightedStats.lowerQuantileLatency(); + final double high = + Math.max( + weightedStats.higherQuantileLatency(), + low * 1.001); // ensure higherQuantile > lowerQuantile + .1% + final double bandWidth = Math.max(high - low, 1); + + if (latency < low) { + latency /= calculateFactor(low, latency, bandWidth); + } else if (latency > high) { + latency *= calculateFactor(latency, high, bandWidth); + } + + return (rSocket.availability() * weightedStats.weightedAvailability()) + / (1.0d + latency * (pending + 1)); + } + + private static double calculateFactor(final double u, final double l, final double bandWidth) { + final double alpha = (u - l) / bandWidth; + return Math.pow(1 + alpha, EXP_FACTOR); + } + + /** + * Create an instance of {@link WeightedLoadbalanceStrategy} with default settings, which include + * round-robin load-balancing and 5 {@link #maxPairSelectionAttempts}. + */ + public static WeightedLoadbalanceStrategy create() { + return new Builder().build(); + } + + /** Return a builder to create a {@link WeightedLoadbalanceStrategy} with. */ + public static Builder builder() { + return new Builder(); + } + + /** Builder for {@link WeightedLoadbalanceStrategy}. */ + public static class Builder { + + private int maxPairSelectionAttempts = 5; + + @Nullable private Function weightedStatsResolver; + + private Builder() {} + + /** + * How many times to try to randomly select a pair of RSocket connections with non-zero + * availability. This is applicable when there are more than two connections in the pool. If the + * number of attempts is exceeded, the last selected pair is used. + * + *

By default this is set to 5. + * + * @param numberOfAttempts the iteration count + */ + public Builder maxPairSelectionAttempts(int numberOfAttempts) { + this.maxPairSelectionAttempts = numberOfAttempts; + return this; + } + + /** + * Configure how the created {@link WeightedLoadbalanceStrategy} should find the stats for a + * given RSocket. + * + *

By default this resolver is not set. + * + *

When {@code WeightedLoadbalanceStrategy} is used through the {@link + * LoadbalanceRSocketClient}, the resolver does not need to be set because a {@link + * WeightedStatsRequestInterceptor} is automatically installed through the {@link + * ClientLoadbalanceStrategy} callback. If this strategy is used in any other context however, a + * resolver here must be provided. + * + * @param resolver to find the stats for an RSocket with + */ + public Builder weightedStatsResolver(Function resolver) { + this.weightedStatsResolver = resolver; + return this; + } + + /** Build the {@code WeightedLoadbalanceStrategy} instance. */ + public WeightedLoadbalanceStrategy build() { + return new WeightedLoadbalanceStrategy( + this.maxPairSelectionAttempts, this.weightedStatsResolver); + } + } + + private static class DefaultWeightedStatsResolver implements Function { + + final Map statsMap = new ConcurrentHashMap<>(); + + @Override + public WeightedStats apply(RSocket rSocket) { + return statsMap.get(rSocket); + } + + void init(RSocketConnector connector) { + connector.interceptors( + registry -> + registry.forRequestsInRequester( + (Function) + rSocket -> { + final WeightedStatsRequestInterceptor interceptor = + new WeightedStatsRequestInterceptor() { + @Override + public void dispose() { + statsMap.remove(rSocket); + } + }; + statsMap.put(rSocket, interceptor); + + return interceptor; + })); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStats.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStats.java new file mode 100644 index 000000000..5ebe668ce --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStats.java @@ -0,0 +1,50 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.RSocket; + +/** + * Contract to expose the stats required in {@link WeightedLoadbalanceStrategy} to calculate an + * algorithmic weight for an {@code RSocket}. The weight helps to select an {@code RSocket} for + * load-balancing. + * + * @since 1.1 + */ +public interface WeightedStats { + + double higherQuantileLatency(); + + double lowerQuantileLatency(); + + int pending(); + + double predictedLatency(); + + double weightedAvailability(); + + /** + * Create a proxy for the given {@code RSocket} that attaches the stats contained in this instance + * and exposes them as {@link WeightedStats}. + * + * @param rsocket the RSocket to wrap + * @return the wrapped RSocket + * @since 1.1.1 + */ + default RSocket wrap(RSocket rsocket) { + return new WeightedStatsRSocketProxy(rsocket, this); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRSocketProxy.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRSocketProxy.java new file mode 100644 index 000000000..f2cf3fbd0 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRSocketProxy.java @@ -0,0 +1,62 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.RSocket; +import io.rsocket.util.RSocketProxy; + +/** + * Package private {@code RSocketProxy} used from {@link WeightedStats#wrap(RSocket)} to attach a + * {@link WeightedStats} instance to an {@code RSocket}. + */ +class WeightedStatsRSocketProxy extends RSocketProxy implements WeightedStats { + + private final WeightedStats weightedStats; + + public WeightedStatsRSocketProxy(RSocket source, WeightedStats weightedStats) { + super(source); + this.weightedStats = weightedStats; + } + + @Override + public double higherQuantileLatency() { + return this.weightedStats.higherQuantileLatency(); + } + + @Override + public double lowerQuantileLatency() { + return this.weightedStats.lowerQuantileLatency(); + } + + @Override + public int pending() { + return this.weightedStats.pending(); + } + + @Override + public double predictedLatency() { + return this.weightedStats.predictedLatency(); + } + + @Override + public double weightedAvailability() { + return this.weightedStats.weightedAvailability(); + } + + public WeightedStats getDelegate() { + return this.weightedStats; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRequestInterceptor.java new file mode 100644 index 000000000..ec2c88b19 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRequestInterceptor.java @@ -0,0 +1,112 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.RequestInterceptor; +import reactor.util.annotation.Nullable; + +/** + * {@link RequestInterceptor} that hooks into request lifecycle and calls methods of the parent + * class to manage tracking state and expose {@link WeightedStats}. + * + *

This interceptor the default mechanism for gathering stats when {@link + * WeightedLoadbalanceStrategy} is used with {@link LoadbalanceRSocketClient}. + * + * @since 1.1 + * @see LoadbalanceRSocketClient + * @see WeightedLoadbalanceStrategy + */ +public class WeightedStatsRequestInterceptor extends BaseWeightedStats + implements RequestInterceptor { + + final Int2LongHashMap requestsStartTime = new Int2LongHashMap(-1); + + public WeightedStatsRequestInterceptor() { + super(); + } + + @Override + public final void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + switch (requestType) { + case REQUEST_FNF: + case REQUEST_RESPONSE: + final long startTime = startRequest(); + final Int2LongHashMap requestsStartTime = this.requestsStartTime; + synchronized (requestsStartTime) { + requestsStartTime.put(streamId, startTime); + } + break; + case REQUEST_STREAM: + case REQUEST_CHANNEL: + this.startStream(); + } + } + + @Override + public final void onTerminate(int streamId, FrameType requestType, @Nullable Throwable t) { + switch (requestType) { + case REQUEST_FNF: + case REQUEST_RESPONSE: + long startTime; + final Int2LongHashMap requestsStartTime = this.requestsStartTime; + synchronized (requestsStartTime) { + startTime = requestsStartTime.remove(streamId); + } + long endTime = stopRequest(startTime); + if (t == null) { + record(endTime - startTime); + } + break; + case REQUEST_STREAM: + case REQUEST_CHANNEL: + stopStream(); + break; + } + + if (t != null) { + updateAvailability(0.0d); + } else { + updateAvailability(1.0d); + } + } + + @Override + public final void onCancel(int streamId, FrameType requestType) { + switch (requestType) { + case REQUEST_FNF: + case REQUEST_RESPONSE: + long startTime; + final Int2LongHashMap requestsStartTime = this.requestsStartTime; + synchronized (requestsStartTime) { + startTime = requestsStartTime.remove(streamId); + } + stopRequest(startTime); + break; + case REQUEST_STREAM: + case REQUEST_CHANNEL: + stopStream(); + break; + } + } + + @Override + public final void onReject(Throwable rejectionReason, FrameType requestType, ByteBuf metadata) {} + + @Override + public void dispose() {} +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/package-info.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/package-info.java new file mode 100644 index 000000000..19668e99c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Support client load-balancing in RSocket Java. */ +@NonNullApi +package io.rsocket.loadbalance; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/AuthMetadataCodec.java b/rsocket-core/src/main/java/io/rsocket/metadata/AuthMetadataCodec.java new file mode 100644 index 000000000..c16c4dc52 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/AuthMetadataCodec.java @@ -0,0 +1,334 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.rsocket.util.CharByteBufUtil; + +public class AuthMetadataCodec { + + static final int STREAM_METADATA_KNOWN_MASK = 0x80; // 1000 0000 + static final byte STREAM_METADATA_LENGTH_MASK = 0x7F; // 0111 1111 + + static final int USERNAME_BYTES_LENGTH = 2; + static final int AUTH_TYPE_ID_LENGTH = 1; + + static final char[] EMPTY_CHARS_ARRAY = new char[0]; + + private AuthMetadataCodec() {} + + /** + * Encode a Authentication CompositeMetadata payload using custom authentication type + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param customAuthType the custom mime type to encode. + * @param metadata the metadata value to encode. + * @throws IllegalArgumentException in case of {@code customAuthType} is non US_ASCII string or + * empty string or its length is greater than 128 bytes + */ + public static ByteBuf encodeMetadata( + ByteBufAllocator allocator, String customAuthType, ByteBuf metadata) { + + int actualASCIILength = ByteBufUtil.utf8Bytes(customAuthType); + if (actualASCIILength != customAuthType.length()) { + throw new IllegalArgumentException("custom auth type must be US_ASCII characters only"); + } + if (actualASCIILength < 1 || actualASCIILength > 128) { + throw new IllegalArgumentException( + "custom auth type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + int capacity = 1 + actualASCIILength; + ByteBuf headerBuffer = allocator.buffer(capacity, capacity); + // encoded length is one less than actual length, since 0 is never a valid length, which gives + // wider representation range + headerBuffer.writeByte(actualASCIILength - 1); + + ByteBufUtil.reserveAndWriteUtf8(headerBuffer, customAuthType, actualASCIILength); + + return allocator.compositeBuffer(2).addComponents(true, headerBuffer, metadata); + } + + /** + * Encode a Authentication CompositeMetadata payload using custom authentication type + * + * @param allocator the {@link ByteBufAllocator} to create intermediate buffers as needed. + * @param authType the well-known mime type to encode. + * @param metadata the metadata value to encode. + * @throws IllegalArgumentException in case of {@code authType} is {@link + * WellKnownAuthType#UNPARSEABLE_AUTH_TYPE} or {@link + * WellKnownAuthType#UNKNOWN_RESERVED_AUTH_TYPE} + */ + public static ByteBuf encodeMetadata( + ByteBufAllocator allocator, WellKnownAuthType authType, ByteBuf metadata) { + + if (authType == WellKnownAuthType.UNPARSEABLE_AUTH_TYPE + || authType == WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE) { + throw new IllegalArgumentException("only allowed AuthType should be used"); + } + + int capacity = AUTH_TYPE_ID_LENGTH; + ByteBuf headerBuffer = + allocator + .buffer(capacity, capacity) + .writeByte(authType.getIdentifier() | STREAM_METADATA_KNOWN_MASK); + + return allocator.compositeBuffer(2).addComponents(true, headerBuffer, metadata); + } + + /** + * Encode a Authentication CompositeMetadata payload using Simple Authentication format + * + * @throws IllegalArgumentException if the username length is greater than 65535 + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param username the char sequence which represents user name. + * @param password the char sequence which represents user password. + */ + public static ByteBuf encodeSimpleMetadata( + ByteBufAllocator allocator, char[] username, char[] password) { + + int usernameLength = CharByteBufUtil.utf8Bytes(username); + if (usernameLength > 65535) { + throw new IllegalArgumentException( + "Username should be shorter than or equal to 65535 bytes length in UTF-8 encoding"); + } + + int passwordLength = CharByteBufUtil.utf8Bytes(password); + int capacity = AUTH_TYPE_ID_LENGTH + USERNAME_BYTES_LENGTH + usernameLength + passwordLength; + final ByteBuf buffer = + allocator + .buffer(capacity, capacity) + .writeByte(WellKnownAuthType.SIMPLE.getIdentifier() | STREAM_METADATA_KNOWN_MASK) + .writeShort(usernameLength); + + CharByteBufUtil.writeUtf8(buffer, username); + CharByteBufUtil.writeUtf8(buffer, password); + + return buffer; + } + + /** + * Encode a Authentication CompositeMetadata payload using Bearer Authentication format + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param token the char sequence which represents BEARER token. + */ + public static ByteBuf encodeBearerMetadata(ByteBufAllocator allocator, char[] token) { + + int tokenLength = CharByteBufUtil.utf8Bytes(token); + int capacity = AUTH_TYPE_ID_LENGTH + tokenLength; + final ByteBuf buffer = + allocator + .buffer(capacity, capacity) + .writeByte(WellKnownAuthType.BEARER.getIdentifier() | STREAM_METADATA_KNOWN_MASK); + + CharByteBufUtil.writeUtf8(buffer, token); + + return buffer; + } + + /** + * Encode a new Authentication Metadata payload information, first verifying if the passed {@link + * String} matches a {@link WellKnownAuthType} (in which case it will be encoded in a compressed + * fashion using the mime id of that type). + * + *

Prefer using {@link #encodeMetadata(ByteBufAllocator, String, ByteBuf)} if you already know + * that the mime type is not a {@link WellKnownAuthType}. + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param authType the mime type to encode, as a {@link String}. well known mime types are + * compressed. + * @param metadata the metadata value to encode. + * @see #encodeMetadata(ByteBufAllocator, WellKnownAuthType, ByteBuf) + * @see #encodeMetadata(ByteBufAllocator, String, ByteBuf) + */ + public static ByteBuf encodeMetadataWithCompression( + ByteBufAllocator allocator, String authType, ByteBuf metadata) { + WellKnownAuthType wkn = WellKnownAuthType.fromString(authType); + if (wkn == WellKnownAuthType.UNPARSEABLE_AUTH_TYPE) { + return AuthMetadataCodec.encodeMetadata(allocator, authType, metadata); + } else { + return AuthMetadataCodec.encodeMetadata(allocator, wkn, metadata); + } + } + + /** + * Get the first {@code byte} from a {@link ByteBuf} and check whether it is length or {@link + * WellKnownAuthType}. Assuming said buffer properly contains such a {@code byte} + * + * @param metadata byteBuf used to get information from + */ + public static boolean isWellKnownAuthType(ByteBuf metadata) { + byte lengthOrId = metadata.getByte(0); + return (lengthOrId & STREAM_METADATA_LENGTH_MASK) != lengthOrId; + } + + /** + * Read first byte from the given {@code metadata} and tries to convert it's value to {@link + * WellKnownAuthType}. + * + * @param metadata given metadata buffer to read from + * @return Return on of the know Auth types or {@link WellKnownAuthType#UNPARSEABLE_AUTH_TYPE} if + * field's value is length or unknown auth type + * @throws IllegalStateException if not enough readable bytes in the given {@link ByteBuf} + */ + public static WellKnownAuthType readWellKnownAuthType(ByteBuf metadata) { + if (metadata.readableBytes() < 1) { + throw new IllegalStateException( + "Unable to decode Well Know Auth type. Not enough readable bytes"); + } + byte lengthOrId = metadata.readByte(); + int normalizedId = (byte) (lengthOrId & STREAM_METADATA_LENGTH_MASK); + + if (normalizedId != lengthOrId) { + return WellKnownAuthType.fromIdentifier(normalizedId); + } + + return WellKnownAuthType.UNPARSEABLE_AUTH_TYPE; + } + + /** + * Read up to 129 bytes from the given metadata in order to get the custom Auth Type + * + * @param metadata + * @return + */ + public static CharSequence readCustomAuthType(ByteBuf metadata) { + if (metadata.readableBytes() < 2) { + throw new IllegalStateException( + "Unable to decode custom Auth type. Not enough readable bytes"); + } + + byte encodedLength = metadata.readByte(); + if (encodedLength < 0) { + throw new IllegalStateException( + "Unable to decode custom Auth type. Incorrect auth type length"); + } + + // encoded length is realLength - 1 in order to avoid intersection with 0x00 authtype + int realLength = encodedLength + 1; + if (metadata.readableBytes() < realLength) { + throw new IllegalArgumentException( + "Unable to decode custom Auth type. Malformed length or auth type string"); + } + + return metadata.readCharSequence(realLength, CharsetUtil.US_ASCII); + } + + /** + * Read all remaining {@code bytes} from the given {@link ByteBuf} and return sliced + * representation of a payload + * + * @param metadata metadata to get payload from. Please note, the {@code metadata#readIndex} + * should be set to the beginning of the payload bytes + * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if no bytes readable in the + * given one + */ + public static ByteBuf readPayload(ByteBuf metadata) { + if (metadata.readableBytes() == 0) { + return Unpooled.EMPTY_BUFFER; + } + + return metadata.readSlice(metadata.readableBytes()); + } + + /** + * Read up to 65537 {@code bytes} from the given {@link ByteBuf} where the first two bytes + * represent username length and the subsequent number of bytes equal to read length + * + * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the username length position + * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if username length is zero + */ + public static ByteBuf readUsername(ByteBuf simpleAuthMetadata) { + int usernameLength = readUsernameLength(simpleAuthMetadata); + + if (usernameLength == 0) { + return Unpooled.EMPTY_BUFFER; + } + + return simpleAuthMetadata.readSlice(usernameLength); + } + + /** + * Read all the remaining {@code byte}s from the given {@link ByteBuf} which represents user's + * password + * + * @param simpleAuthMetadata the given metadata to read password from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes + * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if password length is zero + */ + public static ByteBuf readPassword(ByteBuf simpleAuthMetadata) { + if (simpleAuthMetadata.readableBytes() == 0) { + return Unpooled.EMPTY_BUFFER; + } + + return simpleAuthMetadata.readSlice(simpleAuthMetadata.readableBytes()); + } + /** + * Read up to 65537 {@code bytes} from the given {@link ByteBuf} where the first two bytes + * represent username length and the subsequent number of bytes equal to read length + * + * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the username length byte + * @return {@code char[]} which represents UTF-8 username + */ + public static char[] readUsernameAsCharArray(ByteBuf simpleAuthMetadata) { + int usernameLength = readUsernameLength(simpleAuthMetadata); + + if (usernameLength == 0) { + return EMPTY_CHARS_ARRAY; + } + + return CharByteBufUtil.readUtf8(simpleAuthMetadata, usernameLength); + } + + /** + * Read all the remaining {@code byte}s from the given {@link ByteBuf} which represents user's + * password + * + * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes + * @return {@code char[]} which represents UTF-8 password + */ + public static char[] readPasswordAsCharArray(ByteBuf simpleAuthMetadata) { + if (simpleAuthMetadata.readableBytes() == 0) { + return EMPTY_CHARS_ARRAY; + } + + return CharByteBufUtil.readUtf8(simpleAuthMetadata, simpleAuthMetadata.readableBytes()); + } + + /** + * Read all the remaining {@code bytes} from the given {@link ByteBuf} + * + * @param bearerAuthMetadata the given metadata to read username from. Please note, the {@code + * bearerAuthMetadata#readIndex} should be set to the beginning of the password bytes + * @return {@code char[]} which represents UTF-8 password + */ + public static char[] readBearerTokenAsCharArray(ByteBuf bearerAuthMetadata) { + if (bearerAuthMetadata.readableBytes() == 0) { + return EMPTY_CHARS_ARRAY; + } + + return CharByteBufUtil.readUtf8(bearerAuthMetadata, bearerAuthMetadata.readableBytes()); + } + + private static int readUsernameLength(ByteBuf simpleAuthMetadata) { + if (simpleAuthMetadata.readableBytes() < 2) { + throw new IllegalStateException( + "Unable to decode custom username. Not enough readable bytes"); + } + + int usernameLength = simpleAuthMetadata.readUnsignedShort(); + + if (simpleAuthMetadata.readableBytes() < usernameLength) { + throw new IllegalArgumentException( + "Unable to decode username. Malformed username length or content"); + } + + return usernameLength; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadata.java b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadata.java new file mode 100644 index 000000000..1c3ae9423 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadata.java @@ -0,0 +1,241 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import static io.rsocket.metadata.CompositeMetadataCodec.computeNextEntryIndex; +import static io.rsocket.metadata.CompositeMetadataCodec.decodeMimeAndContentBuffersSlices; +import static io.rsocket.metadata.CompositeMetadataCodec.decodeMimeIdFromMimeBuffer; +import static io.rsocket.metadata.CompositeMetadataCodec.decodeMimeTypeFromMimeBuffer; +import static io.rsocket.metadata.CompositeMetadataCodec.hasEntry; +import static io.rsocket.metadata.CompositeMetadataCodec.isWellKnownMimeType; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.metadata.CompositeMetadata.Entry; +import java.util.Iterator; +import java.util.Spliterator; +import java.util.Spliterators; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; +import reactor.util.annotation.Nullable; + +/** + * An {@link Iterable} wrapper around a {@link ByteBuf} that exposes metadata entry information at + * each decoding step. This is only possible on frame types used to initiate interactions, if the + * SETUP metadata mime type was {@link WellKnownMimeType#MESSAGE_RSOCKET_COMPOSITE_METADATA}. + * + *

This allows efficient incremental decoding of the entries (without moving the source's {@link + * io.netty.buffer.ByteBuf#readerIndex()}). The buffer is assumed to contain just enough bytes to + * represent one or more entries (mime type compressed or not). The decoding stops when the buffer + * reaches 0 readable bytes, and fails if it contains bytes but not enough to correctly decode an + * entry. + * + *

A note on future-proofness: it is possible to come across a compressed mime type that this + * implementation doesn't recognize. This is likely to be due to the use of a byte id that is merely + * reserved in this implementation, but maps to a {@link WellKnownMimeType} in the implementation + * that encoded the metadata. This can be detected by detecting that an entry is a {@link + * ReservedMimeTypeEntry}. In this case {@link Entry#getMimeType()} will return {@code null}. The + * encoded id can be retrieved using {@link ReservedMimeTypeEntry#getType()}. The byte and content + * buffer should be kept around and re-encoded using {@link + * CompositeMetadataCodec#encodeAndAddMetadata(CompositeByteBuf, ByteBufAllocator, byte, ByteBuf)} + * in case passing that entry through is required. + */ +public final class CompositeMetadata implements Iterable { + + private final boolean retainSlices; + + private final ByteBuf source; + + public CompositeMetadata(ByteBuf source, boolean retainSlices) { + this.source = source; + this.retainSlices = retainSlices; + } + + /** + * Turn this {@link CompositeMetadata} into a sequential {@link Stream}. + * + * @return the composite metadata sequential {@link Stream} + */ + public Stream stream() { + return StreamSupport.stream( + Spliterators.spliteratorUnknownSize( + iterator(), Spliterator.DISTINCT | Spliterator.NONNULL | Spliterator.ORDERED), + false); + } + + /** + * An {@link Iterator} that lazily decodes {@link Entry} in this composite metadata. + * + * @return the composite metadata {@link Iterator} + */ + @Override + public Iterator iterator() { + return new Iterator() { + + private int entryIndex = 0; + + @Override + public boolean hasNext() { + return hasEntry(CompositeMetadata.this.source, this.entryIndex); + } + + @Override + public Entry next() { + ByteBuf[] headerAndData = + decodeMimeAndContentBuffersSlices( + CompositeMetadata.this.source, + this.entryIndex, + CompositeMetadata.this.retainSlices); + + ByteBuf header = headerAndData[0]; + ByteBuf data = headerAndData[1]; + + this.entryIndex = computeNextEntryIndex(this.entryIndex, header, data); + + if (!isWellKnownMimeType(header)) { + CharSequence typeString = decodeMimeTypeFromMimeBuffer(header); + if (typeString == null) { + throw new IllegalStateException("MIME type cannot be null"); + } + + return new ExplicitMimeTimeEntry(data, typeString.toString()); + } + + byte id = decodeMimeIdFromMimeBuffer(header); + WellKnownMimeType type = WellKnownMimeType.fromIdentifier(id); + + if (WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE == type) { + return new ReservedMimeTypeEntry(data, id); + } + + return new WellKnownMimeTypeEntry(data, type); + } + }; + } + + /** An entry in the {@link CompositeMetadata}. */ + public interface Entry { + + /** + * Returns the un-decoded content of the {@link Entry}. + * + * @return the un-decoded content of the {@link Entry} + */ + ByteBuf getContent(); + + /** + * Returns the MIME type of the entry, if it can be decoded. + * + * @return the MIME type of the entry, if it can be decoded, otherwise {@code null}. + */ + @Nullable + String getMimeType(); + } + + /** An {@link Entry} backed by an explicitly declared MIME type. */ + public static final class ExplicitMimeTimeEntry implements Entry { + + private final ByteBuf content; + + private final String type; + + public ExplicitMimeTimeEntry(ByteBuf content, String type) { + this.content = content; + this.type = type; + } + + @Override + public ByteBuf getContent() { + return this.content; + } + + @Override + public String getMimeType() { + return this.type; + } + } + + /** + * An {@link Entry} backed by a {@link WellKnownMimeType} entry, but one that is not understood by + * this implementation. + */ + public static final class ReservedMimeTypeEntry implements Entry { + private final ByteBuf content; + private final int type; + + public ReservedMimeTypeEntry(ByteBuf content, int type) { + this.content = content; + this.type = type; + } + + @Override + public ByteBuf getContent() { + return this.content; + } + + /** + * {@inheritDoc} Since this entry represents a compressed id that couldn't be decoded, this is + * always {@code null}. + */ + @Override + public String getMimeType() { + return null; + } + + /** + * Returns the reserved, but unknown {@link WellKnownMimeType} for this entry. Range is 0-127 + * (inclusive). + * + * @return the reserved, but unknown {@link WellKnownMimeType} for this entry + */ + public int getType() { + return this.type; + } + } + + /** An {@link Entry} backed by a {@link WellKnownMimeType}. */ + public static final class WellKnownMimeTypeEntry implements Entry { + + private final ByteBuf content; + private final WellKnownMimeType type; + + public WellKnownMimeTypeEntry(ByteBuf content, WellKnownMimeType type) { + this.content = content; + this.type = type; + } + + @Override + public ByteBuf getContent() { + return this.content; + } + + @Override + public String getMimeType() { + return this.type.getString(); + } + + /** + * Returns the {@link WellKnownMimeType} for this entry. + * + * @return the {@link WellKnownMimeType} for this entry + */ + public WellKnownMimeType getType() { + return this.type; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataCodec.java b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataCodec.java new file mode 100644 index 000000000..5e00abba8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataCodec.java @@ -0,0 +1,385 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.CharsetUtil; +import io.rsocket.util.NumberUtils; +import reactor.util.annotation.Nullable; + +/** + * A flyweight class that can be used to encode/decode composite metadata information to/from {@link + * ByteBuf}. This is intended for low-level efficient manipulation of such buffers. See {@link + * CompositeMetadata} for an Iterator-like approach to decoding entries. + */ +public class CompositeMetadataCodec { + + static final int STREAM_METADATA_KNOWN_MASK = 0x80; // 1000 0000 + + static final byte STREAM_METADATA_LENGTH_MASK = 0x7F; // 0111 1111 + + private CompositeMetadataCodec() {} + + public static int computeNextEntryIndex( + int currentEntryIndex, ByteBuf headerSlice, ByteBuf contentSlice) { + return currentEntryIndex + + headerSlice.readableBytes() // this includes the mime length byte + + 3 // 3 bytes of the content length, which are excluded from the slice + + contentSlice.readableBytes(); + } + + /** + * Decode the next metadata entry (a mime header + content pair of {@link ByteBuf}) from a {@link + * ByteBuf} that contains at least enough bytes for one more such entry. These buffers are + * actually slices of the full metadata buffer, and this method doesn't move the full metadata + * buffer's {@link ByteBuf#readerIndex()}. As such, it requires the user to provide an {@code + * index} to read from. The next index is computed by calling {@link #computeNextEntryIndex(int, + * ByteBuf, ByteBuf)}. Size of the first buffer (the "header buffer") drives which decoding method + * should be further applied to it. + * + *

The header buffer is either: + * + *

    + *
  • made up of a single byte: this represents an encoded mime id, which can be further + * decoded using {@link #decodeMimeIdFromMimeBuffer(ByteBuf)} + *
  • made up of 2 or more bytes: this represents an encoded mime String + its length, which + * can be further decoded using {@link #decodeMimeTypeFromMimeBuffer(ByteBuf)}. Note the + * encoded length, in the first byte, is skipped by this decoding method because the + * remaining length of the buffer is that of the mime string. + *
+ * + * @param compositeMetadata the source {@link ByteBuf} that originally contains one or more + * metadata entries + * @param entryIndex the {@link ByteBuf#readerIndex()} to start decoding from. original reader + * index is kept on the source buffer + * @param retainSlices should produced metadata entry buffers {@link ByteBuf#slice() slices} be + * {@link ByteBuf#retainedSlice() retained}? + * @return a {@link ByteBuf} array of length 2 containing the mime header buffer + * slice and the content buffer slice, or one of the + * zero-length error constant arrays + */ + public static ByteBuf[] decodeMimeAndContentBuffersSlices( + ByteBuf compositeMetadata, int entryIndex, boolean retainSlices) { + compositeMetadata.markReaderIndex(); + compositeMetadata.readerIndex(entryIndex); + + if (compositeMetadata.isReadable()) { + ByteBuf mime; + int ridx = compositeMetadata.readerIndex(); + byte mimeIdOrLength = compositeMetadata.readByte(); + if ((mimeIdOrLength & STREAM_METADATA_KNOWN_MASK) == STREAM_METADATA_KNOWN_MASK) { + mime = + retainSlices + ? compositeMetadata.retainedSlice(ridx, 1) + : compositeMetadata.slice(ridx, 1); + } else { + // M flag unset, remaining 7 bits are the length of the mime + int mimeLength = Byte.toUnsignedInt(mimeIdOrLength) + 1; + + if (compositeMetadata.isReadable( + mimeLength)) { // need to be able to read an extra mimeLength bytes + // here we need a way for the returned ByteBuf to differentiate between a + // 1-byte length mime type and a 1 byte encoded mime id, preferably without + // re-applying the byte mask. The easiest way is to include the initial byte + // and have further decoding ignore the first byte. 1 byte buffer == id, 2+ byte + // buffer == full mime string. + mime = + retainSlices + ? + // we accommodate that we don't read from current readerIndex, but + // readerIndex - 1 ("0"), for a total slice size of mimeLength + 1 + compositeMetadata.retainedSlice(ridx, mimeLength + 1) + : compositeMetadata.slice(ridx, mimeLength + 1); + // we thus need to skip the bytes we just sliced, but not the flag/length byte + // which was already skipped in initial read + compositeMetadata.skipBytes(mimeLength); + } else { + compositeMetadata.resetReaderIndex(); + throw new IllegalStateException("metadata is malformed"); + } + } + + if (compositeMetadata.isReadable(3)) { + // ensures the length medium can be read + final int metadataLength = compositeMetadata.readUnsignedMedium(); + if (compositeMetadata.isReadable(metadataLength)) { + ByteBuf metadata = + retainSlices + ? compositeMetadata.readRetainedSlice(metadataLength) + : compositeMetadata.readSlice(metadataLength); + compositeMetadata.resetReaderIndex(); + return new ByteBuf[] {mime, metadata}; + } else { + compositeMetadata.resetReaderIndex(); + throw new IllegalStateException("metadata is malformed"); + } + } else { + compositeMetadata.resetReaderIndex(); + throw new IllegalStateException("metadata is malformed"); + } + } + compositeMetadata.resetReaderIndex(); + throw new IllegalArgumentException( + String.format("entry index %d is larger than buffer size", entryIndex)); + } + + /** + * Decode a {@code byte} compressed mime id from a {@link ByteBuf}, assuming said buffer properly + * contains such an id. + * + *

The buffer must have exactly one readable byte, which is assumed to have been tested for + * mime id encoding via the {@link #STREAM_METADATA_KNOWN_MASK} mask ({@code firstByte & + * STREAM_METADATA_KNOWN_MASK) == STREAM_METADATA_KNOWN_MASK}). + * + *

If there is no readable byte, the negative identifier of {@link + * WellKnownMimeType#UNPARSEABLE_MIME_TYPE} is returned. + * + * @param mimeBuffer the buffer that should next contain the compressed mime id byte + * @return the compressed mime id, between 0 and 127, or a negative id if the input is invalid + * @see #decodeMimeTypeFromMimeBuffer(ByteBuf) + */ + public static byte decodeMimeIdFromMimeBuffer(ByteBuf mimeBuffer) { + if (mimeBuffer.readableBytes() != 1) { + return WellKnownMimeType.UNPARSEABLE_MIME_TYPE.getIdentifier(); + } + return (byte) (mimeBuffer.readByte() & STREAM_METADATA_LENGTH_MASK); + } + + /** + * Decode a {@link CharSequence} custome mime type from a {@link ByteBuf}, assuming said buffer + * properly contains such a mime type. + * + *

The buffer must at least have two readable bytes, which distinguishes it from the {@link + * #decodeMimeIdFromMimeBuffer(ByteBuf) compressed id} case. The first byte is a size and the + * remaining bytes must correspond to the {@link CharSequence}, encoded fully in US_ASCII. As a + * result, the first byte can simply be skipped, and the remaining of the buffer be decoded to the + * mime type. + * + *

If the mime header buffer is less than 2 bytes long, returns {@code null}. + * + * @param flyweightMimeBuffer the mime header {@link ByteBuf} that contains length + custom mime + * type + * @return the decoded custom mime type, as a {@link CharSequence}, or null if the input is + * invalid + * @see #decodeMimeIdFromMimeBuffer(ByteBuf) + */ + @Nullable + public static CharSequence decodeMimeTypeFromMimeBuffer(ByteBuf flyweightMimeBuffer) { + if (flyweightMimeBuffer.readableBytes() < 2) { + throw new IllegalStateException("unable to decode explicit MIME type"); + } + // the encoded length is assumed to be kept at the start of the buffer + // but also assumed to be irrelevant because the rest of the slice length + // actually already matches _decoded_length + flyweightMimeBuffer.skipBytes(1); + int mimeStringLength = flyweightMimeBuffer.readableBytes(); + return flyweightMimeBuffer.readCharSequence(mimeStringLength, CharsetUtil.US_ASCII); + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}, without checking if the {@link String} can be matched with a well known compressable + * mime type. Prefer using this method and {@link #encodeAndAddMetadata(CompositeByteBuf, + * ByteBufAllocator, WellKnownMimeType, ByteBuf)} if you know in advance whether or not the mime + * is well known. Otherwise use {@link #encodeAndAddMetadataWithCompression(CompositeByteBuf, + * ByteBufAllocator, String, ByteBuf)} + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param customMimeType the custom mime type to encode. + * @param metadata the metadata value to encode. + */ + // see #encodeMetadataHeader(ByteBufAllocator, String, int) + public static void encodeAndAddMetadata( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + String customMimeType, + ByteBuf metadata) { + compositeMetaData.addComponents( + true, encodeMetadataHeader(allocator, customMimeType, metadata.readableBytes()), metadata); + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}. + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param knownMimeType the {@link WellKnownMimeType} to encode. + * @param metadata the metadata value to encode. + */ + // see #encodeMetadataHeader(ByteBufAllocator, byte, int) + public static void encodeAndAddMetadata( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + WellKnownMimeType knownMimeType, + ByteBuf metadata) { + compositeMetaData.addComponents( + true, + encodeMetadataHeader(allocator, knownMimeType.getIdentifier(), metadata.readableBytes()), + metadata); + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}, first verifying if the passed {@link String} matches a {@link WellKnownMimeType} (in + * which case it will be encoded in a compressed fashion using the mime id of that type). + * + *

Prefer using {@link #encodeAndAddMetadata(CompositeByteBuf, ByteBufAllocator, String, + * ByteBuf)} if you already know that the mime type is not a {@link WellKnownMimeType}. + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param mimeType the mime type to encode, as a {@link String}. well known mime types are + * compressed. + * @param metadata the metadata value to encode. + * @see #encodeAndAddMetadata(CompositeByteBuf, ByteBufAllocator, WellKnownMimeType, ByteBuf) + */ + // see #encodeMetadataHeader(ByteBufAllocator, String, int) + public static void encodeAndAddMetadataWithCompression( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + String mimeType, + ByteBuf metadata) { + WellKnownMimeType wkn = WellKnownMimeType.fromString(mimeType); + if (wkn == WellKnownMimeType.UNPARSEABLE_MIME_TYPE) { + compositeMetaData.addComponents( + true, encodeMetadataHeader(allocator, mimeType, metadata.readableBytes()), metadata); + } else { + compositeMetaData.addComponents( + true, + encodeMetadataHeader(allocator, wkn.getIdentifier(), metadata.readableBytes()), + metadata); + } + } + + /** + * Returns whether there is another entry available at a given index + * + * @param compositeMetadata the buffer to inspect + * @param entryIndex the index to check at + * @return whether there is another entry available at a given index + */ + public static boolean hasEntry(ByteBuf compositeMetadata, int entryIndex) { + return compositeMetadata.writerIndex() - entryIndex > 0; + } + + /** + * Returns whether the header represents a well-known MIME type. + * + * @param header the header to inspect + * @return whether the header represents a well-known MIME type + */ + public static boolean isWellKnownMimeType(ByteBuf header) { + return header.readableBytes() == 1; + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}. + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param unknownCompressedMimeType the id of the {@link + * WellKnownMimeType#UNKNOWN_RESERVED_MIME_TYPE} to encode. + * @param metadata the metadata value to encode. + */ + // see #encodeMetadataHeader(ByteBufAllocator, byte, int) + static void encodeAndAddMetadata( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + byte unknownCompressedMimeType, + ByteBuf metadata) { + compositeMetaData.addComponents( + true, + encodeMetadataHeader(allocator, unknownCompressedMimeType, metadata.readableBytes()), + metadata); + } + + /** + * Encode a custom mime type and a metadata value length into a newly allocated {@link ByteBuf}. + * + *

This larger representation encodes the mime type representation's length on a single byte, + * then the representation itself, then the unsigned metadata value length on 3 additional bytes. + * + * @param allocator the {@link ByteBufAllocator} to use to create the buffer. + * @param customMime a custom mime type to encode. + * @param metadataLength the metadata length to append to the buffer as an unsigned 24 bits + * integer. + * @return the encoded mime and metadata length information + */ + static ByteBuf encodeMetadataHeader( + ByteBufAllocator allocator, String customMime, int metadataLength) { + ByteBuf metadataHeader = allocator.buffer(4 + customMime.length()); + // reserve 1 byte for the customMime length + // /!\ careful not to read that first byte, which is random at this point + int writerIndexInitial = metadataHeader.writerIndex(); + metadataHeader.writerIndex(writerIndexInitial + 1); + + // write the custom mime in UTF8 but validate it is all ASCII-compatible + // (which produces the right result since ASCII chars are still encoded on 1 byte in UTF8) + int customMimeLength = ByteBufUtil.writeUtf8(metadataHeader, customMime); + if (!ByteBufUtil.isText( + metadataHeader, metadataHeader.readerIndex() + 1, customMimeLength, CharsetUtil.US_ASCII)) { + metadataHeader.release(); + throw new IllegalArgumentException("custom mime type must be US_ASCII characters only"); + } + if (customMimeLength < 1 || customMimeLength > 128) { + metadataHeader.release(); + throw new IllegalArgumentException( + "custom mime type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + metadataHeader.markWriterIndex(); + + // go back to beginning and write the length + // encoded length is one less than actual length, since 0 is never a valid length, which gives + // wider representation range + metadataHeader.writerIndex(writerIndexInitial); + metadataHeader.writeByte(customMimeLength - 1); + + // go back to post-mime type and write the metadata content length + metadataHeader.resetWriterIndex(); + NumberUtils.encodeUnsignedMedium(metadataHeader, metadataLength); + + return metadataHeader; + } + + /** + * Encode a {@link WellKnownMimeType well known mime type} and a metadata value length into a + * newly allocated {@link ByteBuf}. + * + *

This compact representation encodes the mime type via its ID on a single byte, and the + * unsigned value length on 3 additional bytes. + * + * @param allocator the {@link ByteBufAllocator} to use to create the buffer. + * @param mimeType a byte identifier of a {@link WellKnownMimeType} to encode. + * @param metadataLength the metadata length to append to the buffer as an unsigned 24 bits + * integer. + * @return the encoded mime and metadata length information + */ + static ByteBuf encodeMetadataHeader( + ByteBufAllocator allocator, byte mimeType, int metadataLength) { + ByteBuf buffer = allocator.buffer(4, 4).writeByte(mimeType | STREAM_METADATA_KNOWN_MASK); + + NumberUtils.encodeUnsignedMedium(buffer, metadataLength); + + return buffer; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/MimeTypeMetadataCodec.java b/rsocket-core/src/main/java/io/rsocket/metadata/MimeTypeMetadataCodec.java new file mode 100644 index 000000000..2e03bd754 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/MimeTypeMetadataCodec.java @@ -0,0 +1,137 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.CharsetUtil; +import java.util.ArrayList; +import java.util.List; + +/** + * Provides support for encoding and decoding the per-stream MIME type to use for payload data. + * + *

For more on the format of the metadata, see the + * Stream Data MIME Types extension specification. + * + * @since 1.1.1 + */ +public class MimeTypeMetadataCodec { + + private static final int STREAM_METADATA_KNOWN_MASK = 0x80; // 1000 0000 + + private static final byte STREAM_METADATA_LENGTH_MASK = 0x7F; // 0111 1111 + + private MimeTypeMetadataCodec() {} + + /** + * Encode a {@link WellKnownMimeType} into a newly allocated single byte {@link ByteBuf}. + * + * @param allocator the allocator to create the buffer with + * @param mimeType well-known MIME type to encode + * @return the resulting buffer + */ + public static ByteBuf encode(ByteBufAllocator allocator, WellKnownMimeType mimeType) { + return allocator.buffer(1, 1).writeByte(mimeType.getIdentifier() | STREAM_METADATA_KNOWN_MASK); + } + + /** + * Encode the given MIME type into a newly allocated {@link ByteBuf}. + * + * @param allocator the allocator to create the buffer with + * @param mimeType MIME type to encode + * @return the resulting buffer + */ + public static ByteBuf encode(ByteBufAllocator allocator, String mimeType) { + if (mimeType == null || mimeType.length() == 0) { + throw new IllegalArgumentException("MIME type is required"); + } + WellKnownMimeType wkn = WellKnownMimeType.fromString(mimeType); + if (wkn == WellKnownMimeType.UNPARSEABLE_MIME_TYPE) { + return encodeCustomMimeType(allocator, mimeType); + } else { + return encode(allocator, wkn); + } + } + + /** + * Encode multiple MIME types into a newly allocated {@link ByteBuf}. + * + * @param allocator the allocator to create the buffer with + * @param mimeTypes MIME types to encode + * @return the resulting buffer + */ + public static ByteBuf encode(ByteBufAllocator allocator, List mimeTypes) { + if (mimeTypes == null || mimeTypes.size() == 0) { + throw new IllegalArgumentException("No MIME types provided"); + } + CompositeByteBuf compositeByteBuf = allocator.compositeBuffer(); + for (String mimeType : mimeTypes) { + ByteBuf byteBuf = encode(allocator, mimeType); + compositeByteBuf.addComponents(true, byteBuf); + } + return compositeByteBuf; + } + + private static ByteBuf encodeCustomMimeType(ByteBufAllocator allocator, String customMimeType) { + ByteBuf byteBuf = allocator.buffer(1 + customMimeType.length()); + + byteBuf.writerIndex(1); + int length = ByteBufUtil.writeUtf8(byteBuf, customMimeType); + + if (!ByteBufUtil.isText(byteBuf, 1, length, CharsetUtil.US_ASCII)) { + byteBuf.release(); + throw new IllegalArgumentException("MIME type must be ASCII characters only"); + } + + if (length < 1 || length > 128) { + byteBuf.release(); + throw new IllegalArgumentException( + "MIME type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + byteBuf.markWriterIndex(); + byteBuf.writerIndex(0); + byteBuf.writeByte(length - 1); + byteBuf.resetWriterIndex(); + + return byteBuf; + } + + /** + * Decode the per-stream MIME type metadata encoded in the given {@link ByteBuf}. + * + * @return the decoded MIME types + */ + public static List decode(ByteBuf byteBuf) { + List mimeTypes = new ArrayList<>(); + while (byteBuf.isReadable()) { + byte idOrLength = byteBuf.readByte(); + if ((idOrLength & STREAM_METADATA_KNOWN_MASK) == STREAM_METADATA_KNOWN_MASK) { + byte id = (byte) (idOrLength & STREAM_METADATA_LENGTH_MASK); + WellKnownMimeType wellKnownMimeType = WellKnownMimeType.fromIdentifier(id); + mimeTypes.add(wellKnownMimeType.toString()); + } else { + int length = Byte.toUnsignedInt(idOrLength) + 1; + mimeTypes.add(byteBuf.readCharSequence(length, CharsetUtil.US_ASCII).toString()); + } + } + return mimeTypes; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/RoutingMetadata.java b/rsocket-core/src/main/java/io/rsocket/metadata/RoutingMetadata.java new file mode 100644 index 000000000..d1f2643dc --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/RoutingMetadata.java @@ -0,0 +1,18 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; + +/** + * Routing Metadata extension from + * https://github.com/rsocket/rsocket/blob/master/Extensions/Routing.md + * + * @author linux_china + */ +public class RoutingMetadata extends TaggingMetadata { + private static final WellKnownMimeType ROUTING_MIME_TYPE = + WellKnownMimeType.MESSAGE_RSOCKET_ROUTING; + + public RoutingMetadata(ByteBuf content) { + super(ROUTING_MIME_TYPE.getString(), content); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadata.java b/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadata.java new file mode 100644 index 000000000..e22d97106 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadata.java @@ -0,0 +1,64 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import java.nio.charset.StandardCharsets; +import java.util.Iterator; +import java.util.Spliterator; +import java.util.Spliterators; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; + +/** + * Tagging metadata from https://github.com/rsocket/rsocket/blob/master/Extensions/Routing.md + * + * @author linux_china + */ +public class TaggingMetadata implements Iterable, CompositeMetadata.Entry { + /** Tag max length in bytes */ + private static int TAG_LENGTH_MAX = 0xFF; + + private String mimeType; + private ByteBuf content; + + public TaggingMetadata(String mimeType, ByteBuf content) { + this.mimeType = mimeType; + this.content = content; + } + + public Stream stream() { + return StreamSupport.stream( + Spliterators.spliteratorUnknownSize( + iterator(), Spliterator.DISTINCT | Spliterator.NONNULL | Spliterator.ORDERED), + false); + } + + @Override + public Iterator iterator() { + return new Iterator() { + @Override + public boolean hasNext() { + return content.readerIndex() < content.capacity(); + } + + @Override + public String next() { + int tagLength = TAG_LENGTH_MAX & content.readByte(); + if (tagLength > 0) { + return content.readSlice(tagLength).toString(StandardCharsets.UTF_8); + } else { + return ""; + } + } + }; + } + + @Override + public ByteBuf getContent() { + return this.content; + } + + @Override + public String getMimeType() { + return this.mimeType; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadataCodec.java b/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadataCodec.java new file mode 100644 index 000000000..d766cf59f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadataCodec.java @@ -0,0 +1,76 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import java.nio.charset.StandardCharsets; +import java.util.Collection; + +/** + * A flyweight class that can be used to encode/decode tagging metadata information to/from {@link + * ByteBuf}. This is intended for low-level efficient manipulation of such buffers. See {@link + * TaggingMetadata} for an Iterator-like approach to decoding entries. + * + * @author linux_china + */ +public class TaggingMetadataCodec { + /** Tag max length in bytes */ + private static int TAG_LENGTH_MAX = 0xFF; + + /** + * create routing metadata + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param tags tag values + * @return routing metadata + */ + public static RoutingMetadata createRoutingMetadata( + ByteBufAllocator allocator, Collection tags) { + return new RoutingMetadata(createTaggingContent(allocator, tags)); + } + + /** + * create tagging metadata from composite metadata entry + * + * @param entry composite metadata entry + * @return tagging metadata + */ + public static TaggingMetadata createTaggingMetadata(CompositeMetadata.Entry entry) { + return new TaggingMetadata(entry.getMimeType(), entry.getContent()); + } + + /** + * create tagging metadata + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param knownMimeType the {@link WellKnownMimeType} to encode. + * @param tags tag values + * @return Tagging Metadata + */ + public static TaggingMetadata createTaggingMetadata( + ByteBufAllocator allocator, String knownMimeType, Collection tags) { + return new TaggingMetadata(knownMimeType, createTaggingContent(allocator, tags)); + } + + /** + * create tagging content + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param tags tag values + * @return tagging content + */ + public static ByteBuf createTaggingContent(ByteBufAllocator allocator, Collection tags) { + CompositeByteBuf taggingContent = allocator.compositeBuffer(); + for (String key : tags) { + int length = ByteBufUtil.utf8Bytes(key); + if (length == 0 || length > TAG_LENGTH_MAX) { + continue; + } + ByteBuf byteBuf = allocator.buffer().writeByte(length); + byteBuf.writeCharSequence(key, StandardCharsets.UTF_8); + taggingContent.addComponent(true, byteBuf); + } + return taggingContent; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadata.java b/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadata.java new file mode 100644 index 000000000..d276a9436 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadata.java @@ -0,0 +1,110 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +/** + * Represents decoded tracing metadata which is fully compatible with Zipkin B3 propagation + * + * @since 1.0 + */ +public final class TracingMetadata { + + final long traceIdHigh; + final long traceId; + private final boolean hasParentId; + final long parentId; + final long spanId; + final boolean isEmpty; + final boolean isNotSampled; + final boolean isSampled; + final boolean isDebug; + + TracingMetadata( + long traceIdHigh, + long traceId, + long spanId, + boolean hasParentId, + long parentId, + boolean isEmpty, + boolean isNotSampled, + boolean isSampled, + boolean isDebug) { + this.traceIdHigh = traceIdHigh; + this.traceId = traceId; + this.spanId = spanId; + this.hasParentId = hasParentId; + this.parentId = parentId; + this.isEmpty = isEmpty; + this.isNotSampled = isNotSampled; + this.isSampled = isSampled; + this.isDebug = isDebug; + } + + /** When non-zero, the trace containing this span uses 128-bit trace identifiers. */ + public long traceIdHigh() { + return traceIdHigh; + } + + /** Unique 8-byte identifier for a trace, set on all spans within it. */ + public long traceId() { + return traceId; + } + + /** Indicates if the parent's {@link #spanId} or if this the root span in a trace. */ + public final boolean hasParent() { + return hasParentId; + } + + /** Returns the parent's {@link #spanId} where zero implies absent. */ + public long parentId() { + return parentId; + } + + /** + * Unique 8-byte identifier of this span within a trace. + * + *

A span is uniquely identified in storage by ({@linkplain #traceId}, {@linkplain #spanId}). + */ + public long spanId() { + return spanId; + } + + /** Indicates that trace IDs should be accepted for tracing. */ + public boolean isSampled() { + return isSampled; + } + + /** Indicates that trace IDs should be force traced. */ + public boolean isDebug() { + return isDebug; + } + + /** Includes that there is sampling information and no trace IDs. */ + public boolean isEmpty() { + return isEmpty; + } + + /** + * Indicated that sampling decision is present. If {@code false} means that decision is unknown + * and says explicitly that {@link #isDebug()} and {@link #isSampled()} also returns {@code + * false}. + */ + public boolean isDecided() { + return isNotSampled || isDebug || isSampled; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadataCodec.java b/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadataCodec.java new file mode 100644 index 000000000..eb44956f6 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadataCodec.java @@ -0,0 +1,172 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; + +/** + * Represents codes for tracing metadata which is fully compatible with Zipkin B3 propagation + * + * @since 1.0 + */ +public class TracingMetadataCodec { + + static final int FLAG_EXTENDED_TRACE_ID_SIZE = 0b0000_1000; + static final int FLAG_INCLUDE_PARENT_ID = 0b0000_0100; + static final int FLAG_NOT_SAMPLED = 0b0001_0000; + static final int FLAG_SAMPLED = 0b0010_0000; + static final int FLAG_DEBUG = 0b0100_0000; + static final int FLAG_IDS_SET = 0b1000_0000; + + public static ByteBuf encodeEmpty(ByteBufAllocator allocator, Flags flag) { + + return encode(allocator, true, 0, 0, false, 0, 0, false, flag); + } + + public static ByteBuf encode128( + ByteBufAllocator allocator, + long traceIdHigh, + long traceId, + long spanId, + long parentId, + Flags flag) { + + return encode(allocator, false, traceIdHigh, traceId, true, spanId, parentId, true, flag); + } + + public static ByteBuf encode128( + ByteBufAllocator allocator, long traceIdHigh, long traceId, long spanId, Flags flag) { + + return encode(allocator, false, traceIdHigh, traceId, true, spanId, 0, false, flag); + } + + public static ByteBuf encode64( + ByteBufAllocator allocator, long traceId, long spanId, long parentId, Flags flag) { + + return encode(allocator, false, 0, traceId, false, spanId, parentId, true, flag); + } + + public static ByteBuf encode64( + ByteBufAllocator allocator, long traceId, long spanId, Flags flag) { + return encode(allocator, false, 0, traceId, false, spanId, 0, false, flag); + } + + static ByteBuf encode( + ByteBufAllocator allocator, + boolean isEmpty, + long traceIdHigh, + long traceId, + boolean extendedTraceId, + long spanId, + long parentId, + boolean includesParent, + Flags flag) { + int size = + 1 + + (isEmpty + ? 0 + : (Long.BYTES + + Long.BYTES + + (extendedTraceId ? Long.BYTES : 0) + + (includesParent ? Long.BYTES : 0))); + final ByteBuf buffer = allocator.buffer(size); + + int byteFlags = 0; + switch (flag) { + case NOT_SAMPLE: + byteFlags |= FLAG_NOT_SAMPLED; + break; + case SAMPLE: + byteFlags |= FLAG_SAMPLED; + break; + case DEBUG: + byteFlags |= FLAG_DEBUG; + break; + } + + if (isEmpty) { + return buffer.writeByte(byteFlags); + } + + byteFlags |= FLAG_IDS_SET; + + if (extendedTraceId) { + byteFlags |= FLAG_EXTENDED_TRACE_ID_SIZE; + } + + if (includesParent) { + byteFlags |= FLAG_INCLUDE_PARENT_ID; + } + + buffer.writeByte(byteFlags); + + if (extendedTraceId) { + buffer.writeLong(traceIdHigh); + } + + buffer.writeLong(traceId).writeLong(spanId); + + if (includesParent) { + buffer.writeLong(parentId); + } + + return buffer; + } + + public static TracingMetadata decode(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + try { + byte flags = byteBuf.readByte(); + boolean isNotSampled = (flags & FLAG_NOT_SAMPLED) == FLAG_NOT_SAMPLED; + boolean isSampled = (flags & FLAG_SAMPLED) == FLAG_SAMPLED; + boolean isDebug = (flags & FLAG_DEBUG) == FLAG_DEBUG; + boolean isIDSet = (flags & FLAG_IDS_SET) == FLAG_IDS_SET; + + if (!isIDSet) { + return new TracingMetadata(0, 0, 0, false, 0, true, isNotSampled, isSampled, isDebug); + } + + boolean extendedTraceId = + (flags & FLAG_EXTENDED_TRACE_ID_SIZE) == FLAG_EXTENDED_TRACE_ID_SIZE; + + long traceIdHigh; + if (extendedTraceId) { + traceIdHigh = byteBuf.readLong(); + } else { + traceIdHigh = 0; + } + + long traceId = byteBuf.readLong(); + long spanId = byteBuf.readLong(); + + boolean includesParent = (flags & FLAG_INCLUDE_PARENT_ID) == FLAG_INCLUDE_PARENT_ID; + + long parentId; + if (includesParent) { + parentId = byteBuf.readLong(); + } else { + parentId = 0; + } + + return new TracingMetadata( + traceIdHigh, + traceId, + spanId, + includesParent, + parentId, + false, + isNotSampled, + isSampled, + isDebug); + } finally { + byteBuf.resetReaderIndex(); + } + } + + public enum Flags { + UNDECIDED, + NOT_SAMPLE, + SAMPLE, + DEBUG + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownAuthType.java b/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownAuthType.java new file mode 100644 index 000000000..66c98701c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownAuthType.java @@ -0,0 +1,121 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Enumeration of Well Known Auth Types, as defined in the eponymous extension. Such auth types are + * used in composite metadata (which can include routing and/or tracing metadata). Per + * specification, identifiers are between 0 and 127 (inclusive). + */ +public enum WellKnownAuthType { + UNPARSEABLE_AUTH_TYPE("UNPARSEABLE_AUTH_TYPE_DO_NOT_USE", (byte) -2), + UNKNOWN_RESERVED_AUTH_TYPE("UNKNOWN_YET_RESERVED_DO_NOT_USE", (byte) -1), + + SIMPLE("simple", (byte) 0x00), + BEARER("bearer", (byte) 0x01); + // ... reserved for future use ... + + static final WellKnownAuthType[] TYPES_BY_AUTH_ID; + static final Map TYPES_BY_AUTH_STRING; + + static { + // precompute an array of all valid auth ids, filling the blanks with the RESERVED enum + TYPES_BY_AUTH_ID = new WellKnownAuthType[128]; // 0-127 inclusive + Arrays.fill(TYPES_BY_AUTH_ID, UNKNOWN_RESERVED_AUTH_TYPE); + // also prepare a Map of the types by auth string + TYPES_BY_AUTH_STRING = new LinkedHashMap<>(128); + + for (WellKnownAuthType value : values()) { + if (value.getIdentifier() >= 0) { + TYPES_BY_AUTH_ID[value.getIdentifier()] = value; + TYPES_BY_AUTH_STRING.put(value.getString(), value); + } + } + } + + private final byte identifier; + private final String str; + + WellKnownAuthType(String str, byte identifier) { + this.str = str; + this.identifier = identifier; + } + + /** + * Find the {@link WellKnownAuthType} for the given identifier (as an {@code int}). Valid + * identifiers are defined to be integers between 0 and 127, inclusive. Identifiers outside of + * this range will produce the {@link #UNPARSEABLE_AUTH_TYPE}. Additionally, some identifiers in + * that range are still only reserved and don't have a type associated yet: this method returns + * the {@link #UNKNOWN_RESERVED_AUTH_TYPE} when passing such an identifier, which lets call sites + * potentially detect this and keep the original representation when transmitting the associated + * metadata buffer. + * + * @param id the looked up identifier + * @return the {@link WellKnownAuthType}, or {@link #UNKNOWN_RESERVED_AUTH_TYPE} if the id is out + * of the specification's range, or {@link #UNKNOWN_RESERVED_AUTH_TYPE} if the id is one that + * is merely reserved but unknown to this implementation. + */ + public static WellKnownAuthType fromIdentifier(int id) { + if (id < 0x00 || id > 0x7F) { + return UNPARSEABLE_AUTH_TYPE; + } + return TYPES_BY_AUTH_ID[id]; + } + + /** + * Find the {@link WellKnownAuthType} for the given {@link String} representation. If the + * representation is {@code null} or doesn't match a {@link WellKnownAuthType}, the {@link + * #UNPARSEABLE_AUTH_TYPE} is returned. + * + * @param authType the looked up auth type + * @return the matching {@link WellKnownAuthType}, or {@link #UNPARSEABLE_AUTH_TYPE} if none + * matches + */ + public static WellKnownAuthType fromString(String authType) { + if (authType == null) throw new IllegalArgumentException("type must be non-null"); + + // force UNPARSEABLE if by chance UNKNOWN_RESERVED_AUTH_TYPE's text has been used + if (authType.equals(UNKNOWN_RESERVED_AUTH_TYPE.str)) { + return UNPARSEABLE_AUTH_TYPE; + } + + return TYPES_BY_AUTH_STRING.getOrDefault(authType, UNPARSEABLE_AUTH_TYPE); + } + + /** @return the byte identifier of the auth type, guaranteed to be positive or zero. */ + public byte getIdentifier() { + return identifier; + } + + /** + * @return the auth type represented as a {@link String}, which is made of US_ASCII compatible + * characters only + */ + public String getString() { + return str; + } + + /** @see #getString() */ + @Override + public String toString() { + return str; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java b/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java new file mode 100644 index 000000000..e78e87629 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java @@ -0,0 +1,167 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +/** + * Enumeration of Well Known Mime Types, as defined in the eponymous extension. Such mime types are + * used in composite metadata (which can include routing and/or tracing metadata). Per + * specification, identifiers are between 0 and 127 (inclusive). + */ +public enum WellKnownMimeType { + UNPARSEABLE_MIME_TYPE("UNPARSEABLE_MIME_TYPE_DO_NOT_USE", (byte) -2), + UNKNOWN_RESERVED_MIME_TYPE("UNKNOWN_YET_RESERVED_DO_NOT_USE", (byte) -1), + + APPLICATION_AVRO("application/avro", (byte) 0x00), + APPLICATION_CBOR("application/cbor", (byte) 0x01), + APPLICATION_GRAPHQL("application/graphql", (byte) 0x02), + APPLICATION_GZIP("application/gzip", (byte) 0x03), + APPLICATION_JAVASCRIPT("application/javascript", (byte) 0x04), + APPLICATION_JSON("application/json", (byte) 0x05), + APPLICATION_OCTET_STREAM("application/octet-stream", (byte) 0x06), + APPLICATION_PDF("application/pdf", (byte) 0x07), + APPLICATION_THRIFT("application/vnd.apache.thrift.binary", (byte) 0x08), + APPLICATION_PROTOBUF("application/vnd.google.protobuf", (byte) 0x09), + APPLICATION_XML("application/xml", (byte) 0x0A), + APPLICATION_ZIP("application/zip", (byte) 0x0B), + AUDIO_AAC("audio/aac", (byte) 0x0C), + AUDIO_MP3("audio/mp3", (byte) 0x0D), + AUDIO_MP4("audio/mp4", (byte) 0x0E), + AUDIO_MPEG3("audio/mpeg3", (byte) 0x0F), + AUDIO_MPEG("audio/mpeg", (byte) 0x10), + AUDIO_OGG("audio/ogg", (byte) 0x11), + AUDIO_OPUS("audio/opus", (byte) 0x12), + AUDIO_VORBIS("audio/vorbis", (byte) 0x13), + IMAGE_BMP("image/bmp", (byte) 0x14), + IMAGE_GIF("image/gif", (byte) 0x15), + IMAGE_HEIC_SEQUENCE("image/heic-sequence", (byte) 0x16), + IMAGE_HEIC("image/heic", (byte) 0x17), + IMAGE_HEIF_SEQUENCE("image/heif-sequence", (byte) 0x18), + IMAGE_HEIF("image/heif", (byte) 0x19), + IMAGE_JPEG("image/jpeg", (byte) 0x1A), + IMAGE_PNG("image/png", (byte) 0x1B), + IMAGE_TIFF("image/tiff", (byte) 0x1C), + MULTIPART_MIXED("multipart/mixed", (byte) 0x1D), + TEXT_CSS("text/css", (byte) 0x1E), + TEXT_CSV("text/csv", (byte) 0x1F), + TEXT_HTML("text/html", (byte) 0x20), + TEXT_PLAIN("text/plain", (byte) 0x21), + TEXT_XML("text/xml", (byte) 0x22), + VIDEO_H264("video/H264", (byte) 0x23), + VIDEO_H265("video/H265", (byte) 0x24), + VIDEO_VP8("video/VP8", (byte) 0x25), + APPLICATION_HESSIAN("application/x-hessian", (byte) 0x26), + APPLICATION_JAVA_OBJECT("application/x-java-object", (byte) 0x27), + APPLICATION_CLOUDEVENTS_JSON("application/cloudevents+json", (byte) 0x28), + + // ... reserved for future use ... + MESSAGE_RSOCKET_MIMETYPE("message/x.rsocket.mime-type.v0", (byte) 0x7A), + MESSAGE_RSOCKET_ACCEPT_MIMETYPES("message/x.rsocket.accept-mime-types.v0", (byte) 0x7B), + MESSAGE_RSOCKET_AUTHENTICATION("message/x.rsocket.authentication.v0", (byte) 0x7C), + MESSAGE_RSOCKET_TRACING_ZIPKIN("message/x.rsocket.tracing-zipkin.v0", (byte) 0x7D), + MESSAGE_RSOCKET_ROUTING("message/x.rsocket.routing.v0", (byte) 0x7E), + MESSAGE_RSOCKET_COMPOSITE_METADATA("message/x.rsocket.composite-metadata.v0", (byte) 0x7F); + + static final WellKnownMimeType[] TYPES_BY_MIME_ID; + static final Map TYPES_BY_MIME_STRING; + + static { + // precompute an array of all valid mime ids, filling the blanks with the RESERVED enum + TYPES_BY_MIME_ID = new WellKnownMimeType[128]; // 0-127 inclusive + Arrays.fill(TYPES_BY_MIME_ID, UNKNOWN_RESERVED_MIME_TYPE); + // also prepare a Map of the types by mime string + TYPES_BY_MIME_STRING = new HashMap<>(128); + + for (WellKnownMimeType value : values()) { + if (value.getIdentifier() >= 0) { + TYPES_BY_MIME_ID[value.getIdentifier()] = value; + TYPES_BY_MIME_STRING.put(value.getString(), value); + } + } + } + + private final byte identifier; + private final String str; + + WellKnownMimeType(String str, byte identifier) { + this.str = str; + this.identifier = identifier; + } + + /** + * Find the {@link WellKnownMimeType} for the given identifier (as an {@code int}). Valid + * identifiers are defined to be integers between 0 and 127, inclusive. Identifiers outside of + * this range will produce the {@link #UNPARSEABLE_MIME_TYPE}. Additionally, some identifiers in + * that range are still only reserved and don't have a type associated yet: this method returns + * the {@link #UNKNOWN_RESERVED_MIME_TYPE} when passing such an identifier, which lets call sites + * potentially detect this and keep the original representation when transmitting the associated + * metadata buffer. + * + * @param id the looked up identifier + * @return the {@link WellKnownMimeType}, or {@link #UNKNOWN_RESERVED_MIME_TYPE} if the id is out + * of the specification's range, or {@link #UNKNOWN_RESERVED_MIME_TYPE} if the id is one that + * is merely reserved but unknown to this implementation. + */ + public static WellKnownMimeType fromIdentifier(int id) { + if (id < 0x00 || id > 0x7F) { + return UNPARSEABLE_MIME_TYPE; + } + return TYPES_BY_MIME_ID[id]; + } + + /** + * Find the {@link WellKnownMimeType} for the given {@link String} representation. If the + * representation is {@code null} or doesn't match a {@link WellKnownMimeType}, the {@link + * #UNPARSEABLE_MIME_TYPE} is returned. + * + * @param mimeType the looked up mime type + * @return the matching {@link WellKnownMimeType}, or {@link #UNPARSEABLE_MIME_TYPE} if none + * matches + */ + public static WellKnownMimeType fromString(String mimeType) { + if (mimeType == null) throw new IllegalArgumentException("type must be non-null"); + + // force UNPARSEABLE if by chance UNKNOWN_RESERVED_MIME_TYPE's text has been used + if (mimeType.equals(UNKNOWN_RESERVED_MIME_TYPE.str)) { + return UNPARSEABLE_MIME_TYPE; + } + + return TYPES_BY_MIME_STRING.getOrDefault(mimeType, UNPARSEABLE_MIME_TYPE); + } + + /** @return the byte identifier of the mime type, guaranteed to be positive or zero. */ + public byte getIdentifier() { + return identifier; + } + + /** + * @return the mime type represented as a {@link String}, which is made of US_ASCII compatible + * characters only + */ + public String getString() { + return str; + } + + /** @see #getString() */ + @Override + public String toString() { + return str; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/package-info.java b/rsocket-core/src/main/java/io/rsocket/metadata/package-info.java new file mode 100644 index 000000000..3fb9ae1d6 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/package-info.java @@ -0,0 +1,25 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Contains implementations of RSocket protocol extensions related + * to the use of metadata. + */ +@NonNullApi +package io.rsocket.metadata; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/package-info.java b/rsocket-core/src/main/java/io/rsocket/package-info.java index 3e23c5ff1..6fe74fb38 100644 --- a/rsocket-core/src/main/java/io/rsocket/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/package-info.java @@ -1,17 +1,29 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ +/** + * Contains key contracts of the RSocket programming model including {@link io.rsocket.RSocket + * RSocket} for performing or handling RSocket interactions, {@link io.rsocket.SocketAcceptor + * SocketAcceptor} for declaring responders, {@link io.rsocket.Payload Payload} for access to the + * content of a payload, and others. + * + *

To connect to or start a server see {@link io.rsocket.core.RSocketConnector RSocketConnector} + * and {@link io.rsocket.core.RSocketServer RSocketServer} in {@link io.rsocket.core}. + */ +@NonNullApi package io.rsocket; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java new file mode 100644 index 000000000..9a134153d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java @@ -0,0 +1,147 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import java.util.List; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +class CompositeRequestInterceptor implements RequestInterceptor { + + final RequestInterceptor[] requestInterceptors; + + CompositeRequestInterceptor(RequestInterceptor[] requestInterceptors) { + this.requestInterceptors = requestInterceptors; + } + + @Override + public void dispose() { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + requestInterceptor.dispose(); + } + } + + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onStart(streamId, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Override + public void onTerminate(int streamId, FrameType requestType, @Nullable Throwable cause) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onTerminate(streamId, requestType, cause); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onCancel(streamId, requestType); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onReject(rejectionReason, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Nullable + static RequestInterceptor create(List interceptors) { + switch (interceptors.size()) { + case 0: + return null; + case 1: + return new SafeRequestInterceptor(interceptors.get(0)); + default: + return new CompositeRequestInterceptor(interceptors.toArray(new RequestInterceptor[0])); + } + } + + static class SafeRequestInterceptor implements RequestInterceptor { + + final RequestInterceptor requestInterceptor; + + public SafeRequestInterceptor(RequestInterceptor requestInterceptor) { + this.requestInterceptor = requestInterceptor; + } + + @Override + public void dispose() { + requestInterceptor.dispose(); + } + + @Override + public boolean isDisposed() { + return requestInterceptor.isDisposed(); + } + + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + try { + requestInterceptor.onStart(streamId, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + + @Override + public void onTerminate(int streamId, FrameType requestType, @Nullable Throwable cause) { + try { + requestInterceptor.onTerminate(streamId, requestType, cause); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + try { + requestInterceptor.onCancel(streamId, requestType); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + try { + requestInterceptor.onReject(rejectionReason, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java index 98a0d364c..5d3a43b03 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java @@ -1,17 +1,17 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package io.rsocket.plugins; @@ -19,11 +19,17 @@ import io.rsocket.DuplexConnection; import java.util.function.BiFunction; -/** */ +/** + * Contract to decorate a {@link DuplexConnection} and intercept the sending and receiving of + * RSocket frames at the transport level. + */ public @FunctionalInterface interface DuplexConnectionInterceptor extends BiFunction { + enum Type { - STREAM_ZERO, + /** @deprecated since 1.1.0-M2. Will be removed in 1.2 */ + @Deprecated + SETUP, CLIENT, SERVER, SOURCE diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java new file mode 100644 index 000000000..7c9a90f54 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java @@ -0,0 +1,80 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.plugins; + +import io.rsocket.DuplexConnection; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import reactor.util.annotation.Nullable; + +/** + * Extends {@link InterceptorRegistry} with methods for building a chain of registered interceptors. + * This is not intended for direct use by applications. + */ +public class InitializingInterceptorRegistry extends InterceptorRegistry { + + @Nullable + public RequestInterceptor initRequesterRequestInterceptor(RSocket rSocketRequester) { + return CompositeRequestInterceptor.create( + getRequestInterceptorsForRequester() + .stream() + .map(factory -> factory.apply(rSocketRequester)) + .collect(Collectors.toList())); + } + + @Nullable + public RequestInterceptor initResponderRequestInterceptor( + RSocket rSocketResponder, RequestInterceptor... perConnectionInterceptors) { + return CompositeRequestInterceptor.create( + Stream.concat( + Stream.of(perConnectionInterceptors), + getRequestInterceptorsForResponder() + .stream() + .map(inteptorFactory -> inteptorFactory.apply(rSocketResponder))) + .collect(Collectors.toList())); + } + + public DuplexConnection initConnection( + DuplexConnectionInterceptor.Type type, DuplexConnection connection) { + for (DuplexConnectionInterceptor interceptor : getConnectionInterceptors()) { + connection = interceptor.apply(type, connection); + } + return connection; + } + + public RSocket initRequester(RSocket rsocket) { + for (RSocketInterceptor interceptor : getRequesterInterceptors()) { + rsocket = interceptor.apply(rsocket); + } + return rsocket; + } + + public RSocket initResponder(RSocket rsocket) { + for (RSocketInterceptor interceptor : getResponderInterceptors()) { + rsocket = interceptor.apply(rsocket); + } + return rsocket; + } + + public SocketAcceptor initSocketAcceptor(SocketAcceptor acceptor) { + for (SocketAcceptorInterceptor interceptor : getSocketAcceptorInterceptors()) { + acceptor = interceptor.apply(acceptor); + } + return acceptor; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java new file mode 100644 index 000000000..680fb514f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java @@ -0,0 +1,160 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.plugins; + +import io.rsocket.RSocket; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; +import java.util.function.Function; + +/** + * Provides support for registering interceptors at the following levels: + * + *

    + *
  • {@link #forConnection(DuplexConnectionInterceptor)} -- transport level + *
  • {@link #forSocketAcceptor(SocketAcceptorInterceptor)} -- for accepting new connections + *
  • {@link #forRequester(RSocketInterceptor)} -- for performing of requests + *
  • {@link #forResponder(RSocketInterceptor)} -- for responding to requests + *
+ */ +public class InterceptorRegistry { + private List> requesterRequestInterceptors = + new ArrayList<>(); + private List> responderRequestInterceptors = + new ArrayList<>(); + private List requesterRSocketInterceptors = new ArrayList<>(); + private List responderRSocketInterceptors = new ArrayList<>(); + private List socketAcceptorInterceptors = new ArrayList<>(); + private List connectionInterceptors = new ArrayList<>(); + + /** + * Add an {@link RequestInterceptor} that will hook into Requester RSocket requests' phases. + * + * @param interceptor a function which accepts an {@link RSocket} and returns a new {@link + * RequestInterceptor} + * @since 1.1 + */ + public InterceptorRegistry forRequestsInRequester( + Function interceptor) { + requesterRequestInterceptors.add(interceptor); + return this; + } + + /** + * Add an {@link RequestInterceptor} that will hook into Requester RSocket requests' phases. + * + * @param interceptor a function which accepts an {@link RSocket} and returns a new {@link + * RequestInterceptor} + * @since 1.1 + */ + public InterceptorRegistry forRequestsInResponder( + Function interceptor) { + responderRequestInterceptors.add(interceptor); + return this; + } + + /** + * Add an {@link RSocketInterceptor} that will decorate the RSocket used for performing requests. + */ + public InterceptorRegistry forRequester(RSocketInterceptor interceptor) { + requesterRSocketInterceptors.add(interceptor); + return this; + } + + /** + * Variant of {@link #forRequester(RSocketInterceptor)} with access to the list of existing + * registrations. + */ + public InterceptorRegistry forRequester(Consumer> consumer) { + consumer.accept(requesterRSocketInterceptors); + return this; + } + + /** + * Add an {@link RSocketInterceptor} that will decorate the RSocket used for resonding to + * requests. + */ + public InterceptorRegistry forResponder(RSocketInterceptor interceptor) { + responderRSocketInterceptors.add(interceptor); + return this; + } + + /** + * Variant of {@link #forResponder(RSocketInterceptor)} with access to the list of existing + * registrations. + */ + public InterceptorRegistry forResponder(Consumer> consumer) { + consumer.accept(responderRSocketInterceptors); + return this; + } + + /** + * Add a {@link SocketAcceptorInterceptor} that will intercept the accepting of new connections. + */ + public InterceptorRegistry forSocketAcceptor(SocketAcceptorInterceptor interceptor) { + socketAcceptorInterceptors.add(interceptor); + return this; + } + + /** + * Variant of {@link #forSocketAcceptor(SocketAcceptorInterceptor)} with access to the list of + * existing registrations. + */ + public InterceptorRegistry forSocketAcceptor(Consumer> consumer) { + consumer.accept(socketAcceptorInterceptors); + return this; + } + + /** Add a {@link DuplexConnectionInterceptor}. */ + public InterceptorRegistry forConnection(DuplexConnectionInterceptor interceptor) { + connectionInterceptors.add(interceptor); + return this; + } + + /** + * Variant of {@link #forConnection(DuplexConnectionInterceptor)} with access to the list of + * existing registrations. + */ + public InterceptorRegistry forConnection(Consumer> consumer) { + consumer.accept(connectionInterceptors); + return this; + } + + List> getRequestInterceptorsForRequester() { + return requesterRequestInterceptors; + } + + List> getRequestInterceptorsForResponder() { + return responderRequestInterceptors; + } + + List getRequesterInterceptors() { + return requesterRSocketInterceptors; + } + + List getResponderInterceptors() { + return responderRSocketInterceptors; + } + + List getConnectionInterceptors() { + return connectionInterceptors; + } + + List getSocketAcceptorInterceptors() { + return socketAcceptorInterceptors; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/LimitRateInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/LimitRateInterceptor.java new file mode 100644 index 000000000..d7d9742d0 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/LimitRateInterceptor.java @@ -0,0 +1,133 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.plugins; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.util.RSocketProxy; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; + +/** + * Interceptor that adds {@link Flux#limitRate(int, int)} to publishers of outbound streams that + * breaks down or aggregates demand values from the remote end (i.e. {@code REQUEST_N} frames) into + * batches of a uniform size. For example the remote may request {@code Long.MAXVALUE} or it may + * start requesting one at a time, in both cases with the limit set to 64, the publisher will see a + * demand of 64 to start and subsequent batches of 48, i.e. continuing to prefetch and refill an + * internal queue when it falls to 75% full. The high and low tide marks are configurable. + * + *

See static factory methods to create an instance for a requester or for a responder. + * + *

Note: keep in mind that the {@code limitRate} operator always uses requests + * the same request values, even if the remote requests less than the limit. For example given a + * limit of 64, if the remote requests 4, 64 will be prefetched of which 4 will be sent and 60 will + * be cached. + * + * @since 1.0 + */ +public class LimitRateInterceptor implements RSocketInterceptor { + + private final int highTide; + private final int lowTide; + private final boolean requesterProxy; + + private LimitRateInterceptor(int highTide, int lowTide, boolean requesterProxy) { + this.highTide = highTide; + this.lowTide = lowTide; + this.requesterProxy = requesterProxy; + } + + @Override + public RSocket apply(RSocket socket) { + return requesterProxy ? new RequesterProxy(socket) : new ResponderProxy(socket); + } + + /** + * Create an interceptor for an {@code RSocket} that handles request-stream and/or request-channel + * interactions. + * + * @param prefetchRate the prefetch rate to pass to {@link Flux#limitRate(int)} + * @return the created interceptor + */ + public static LimitRateInterceptor forResponder(int prefetchRate) { + return forResponder(prefetchRate, prefetchRate); + } + + /** + * Create an interceptor for an {@code RSocket} that handles request-stream and/or request-channel + * interactions with more control over the overall prefetch rate and replenish threshold. + * + * @param highTide the high tide value to pass to {@link Flux#limitRate(int, int)} + * @param lowTide the low tide value to pass to {@link Flux#limitRate(int, int)} + * @return the created interceptor + */ + public static LimitRateInterceptor forResponder(int highTide, int lowTide) { + return new LimitRateInterceptor(highTide, lowTide, false); + } + + /** + * Create an interceptor for an {@code RSocket} that performs request-channel interactions. + * + * @param prefetchRate the prefetch rate to pass to {@link Flux#limitRate(int)} + * @return the created interceptor + */ + public static LimitRateInterceptor forRequester(int prefetchRate) { + return forRequester(prefetchRate, prefetchRate); + } + + /** + * Create an interceptor for an {@code RSocket} that performs request-channel interactions with + * more control over the overall prefetch rate and replenish threshold. + * + * @param highTide the high tide value to pass to {@link Flux#limitRate(int, int)} + * @param lowTide the low tide value to pass to {@link Flux#limitRate(int, int)} + * @return the created interceptor + */ + public static LimitRateInterceptor forRequester(int highTide, int lowTide) { + return new LimitRateInterceptor(highTide, lowTide, true); + } + + /** Responder side proxy, limits response streams. */ + private class ResponderProxy extends RSocketProxy { + + ResponderProxy(RSocket source) { + super(source); + } + + @Override + public Flux requestStream(Payload payload) { + return super.requestStream(payload).limitRate(highTide, lowTide); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return super.requestChannel(payloads).limitRate(highTide, lowTide); + } + } + + /** Requester side proxy, limits channel request stream. */ + private class RequesterProxy extends RSocketProxy { + + RequesterProxy(RSocket source) { + super(source); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return super.requestChannel(Flux.from(payloads).limitRate(highTide, lowTide)); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/PluginRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/PluginRegistry.java deleted file mode 100644 index 2f00cf95f..000000000 --- a/rsocket-core/src/main/java/io/rsocket/plugins/PluginRegistry.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.plugins; - -import io.rsocket.DuplexConnection; -import io.rsocket.RSocket; -import java.util.ArrayList; -import java.util.List; - -public class PluginRegistry { - private List connections = new ArrayList<>(); - private List clients = new ArrayList<>(); - private List servers = new ArrayList<>(); - - public PluginRegistry() {} - - public PluginRegistry(PluginRegistry defaults) { - this.connections.addAll(defaults.connections); - this.clients.addAll(defaults.clients); - this.servers.addAll(defaults.servers); - } - - public void addConnectionPlugin(DuplexConnectionInterceptor interceptor) { - connections.add(interceptor); - } - - public void addClientPlugin(RSocketInterceptor interceptor) { - clients.add(interceptor); - } - - public void addServerPlugin(RSocketInterceptor interceptor) { - servers.add(interceptor); - } - - public RSocket applyClient(RSocket rSocket) { - for (RSocketInterceptor i : clients) { - rSocket = i.apply(rSocket); - } - - return rSocket; - } - - public RSocket applyServer(RSocket rSocket) { - for (RSocketInterceptor i : servers) { - rSocket = i.apply(rSocket); - } - - return rSocket; - } - - public DuplexConnection applyConnection( - DuplexConnectionInterceptor.Type type, DuplexConnection connection) { - for (DuplexConnectionInterceptor i : connections) { - connection = i.apply(type, connection); - } - - return connection; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/Plugins.java b/rsocket-core/src/main/java/io/rsocket/plugins/Plugins.java deleted file mode 100644 index 13d5db33b..000000000 --- a/rsocket-core/src/main/java/io/rsocket/plugins/Plugins.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.plugins; - -/** JVM wide plugins for RSocket */ -public class Plugins { - private static PluginRegistry DEFAULT = new PluginRegistry(); - - private Plugins() {} - - public static void interceptConnection(DuplexConnectionInterceptor interceptor) { - DEFAULT.addConnectionPlugin(interceptor); - } - - public static void interceptClient(RSocketInterceptor interceptor) { - DEFAULT.addClientPlugin(interceptor); - } - - public static void interceptServer(RSocketInterceptor interceptor) { - DEFAULT.addServerPlugin(interceptor); - } - - public static PluginRegistry defaultPlugins() { - return DEFAULT; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/RSocketInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/RSocketInterceptor.java index 8be627880..0cd4bb8f6 100644 --- a/rsocket-core/src/main/java/io/rsocket/plugins/RSocketInterceptor.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/RSocketInterceptor.java @@ -1,17 +1,17 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package io.rsocket.plugins; @@ -19,5 +19,10 @@ import io.rsocket.RSocket; import java.util.function.Function; -/** */ +/** + * Contract to decorate an {@link RSocket}, providing a way to intercept interactions. This can be + * applied to a {@link InterceptorRegistry#forRequester(RSocketInterceptor) requester} or {@link + * InterceptorRegistry#forResponder(RSocketInterceptor) responder} {@code RSocket} of a client or + * server. + */ public @FunctionalInterface interface RSocketInterceptor extends Function {} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java new file mode 100644 index 000000000..08131b39d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java @@ -0,0 +1,79 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import reactor.core.Disposable; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +/** + * Class used to track the RSocket requests lifecycles. The main difference and advantage of this + * interceptor compares to {@link RSocketInterceptor} is that it allows intercepting the initial and + * terminal phases on every individual request. + * + *

Note, if any of the invocations will rise a runtime exception, this exception will be + * caught and be propagated to {@link reactor.core.publisher.Operators#onErrorDropped(Throwable, + * Context)} + * + * @since 1.1 + */ +public interface RequestInterceptor extends Disposable { + + /** + * Method which is being invoked on successful acceptance and start of a request. + * + * @param streamId used for the request + * @param requestType of the request. Must be one of the following types {@link + * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link + * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} + * @param metadata taken from the initial frame + */ + void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata); + + /** + * Method which is being invoked once a successfully accepted request is terminated. This method + * can be invoked only after the {@link #onStart(int, FrameType, ByteBuf)} method. This method is + * exclusive with {@link #onCancel(int, FrameType)}. + * + * @param streamId used by this request + * @param requestType of the request. Must be one of the following types {@link + * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link + * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} + * @param t with which this finished has terminated. Must be one of the following signals + */ + void onTerminate(int streamId, FrameType requestType, @Nullable Throwable t); + + /** + * Method which is being invoked once a successfully accepted request is cancelled. This method + * can be invoked only after the {@link #onStart(int, FrameType, ByteBuf)} method. This method is + * exclusive with {@link #onTerminate(int, FrameType, Throwable)}. + * + * @param requestType of the request. Must be one of the following types {@link + * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link + * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} + * @param streamId used by this request + */ + void onCancel(int streamId, FrameType requestType); + + /** + * Method which is being invoked on the request rejection. This method is being called only if the + * actual request can not be started and is called instead of the {@link #onStart(int, FrameType, + * ByteBuf)} method. The reason for rejection can be one of the following: + * + *

+ * + *

    + *
  • No available {@link io.rsocket.lease.Lease} on the requester or the responder sides + *
  • Invalid {@link io.rsocket.Payload} size or format on the Requester side, so the request + * is being rejected before the actual streamId is generated + *
  • A second subscription on the ongoing Request + *
+ * + * @param rejectionReason exception which causes rejection of a particular request + * @param requestType of the request. Must be one of the following types {@link + * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link + * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} + * @param metadata taken from the initial frame + */ + void onReject(Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata); +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/SocketAcceptorInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/SocketAcceptorInterceptor.java new file mode 100644 index 000000000..6dd850ba9 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/SocketAcceptorInterceptor.java @@ -0,0 +1,29 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.plugins; + +import io.rsocket.SocketAcceptor; +import java.util.function.Function; + +/** + * Contract to decorate a {@link SocketAcceptor}, providing access to connection {@code setup} + * information and the ability to also decorate the sockets for requesting and responding. + * + *

This could be used as an alternative to registering an individual "requester" {@code + * RSocketInterceptor} and "responder" {@code RSocketInterceptor}. + */ +public @FunctionalInterface interface SocketAcceptorInterceptor + extends Function {} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/package-info.java b/rsocket-core/src/main/java/io/rsocket/plugins/package-info.java new file mode 100644 index 000000000..fd9e1f01a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Contracts for interception of transports, connections, and requests in in RSocket Java. */ +@NonNullApi +package io.rsocket.plugins; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java b/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java new file mode 100644 index 000000000..ca4f5dcb4 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java @@ -0,0 +1,383 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.CharsetUtil; +import io.rsocket.DuplexConnection; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.exceptions.Exceptions; +import io.rsocket.exceptions.RejectedResumeException; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.ResumeFrameCodec; +import io.rsocket.frame.ResumeOkFrameCodec; +import io.rsocket.keepalive.KeepAliveSupport; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Function; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.Disposable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.function.Tuple2; +import reactor.util.retry.Retry; + +public class ClientRSocketSession + implements RSocketSession, + ResumeStateHolder, + CoreSubscriber> { + + private static final Logger logger = LoggerFactory.getLogger(ClientRSocketSession.class); + + final ResumableDuplexConnection resumableConnection; + final Mono> connectionFactory; + final ResumableFramesStore resumableFramesStore; + + final ByteBufAllocator allocator; + final Duration resumeSessionDuration; + final Retry retry; + final boolean cleanupStoreOnKeepAlive; + final ByteBuf resumeToken; + final String session; + final Disposable reconnectDisposable; + + volatile Subscription s; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(ClientRSocketSession.class, Subscription.class, "s"); + + KeepAliveSupport keepAliveSupport; + + public ClientRSocketSession( + ByteBuf resumeToken, + ResumableDuplexConnection resumableDuplexConnection, + Mono connectionFactory, + Function>> connectionTransformer, + ResumableFramesStore resumableFramesStore, + Duration resumeSessionDuration, + Retry retry, + boolean cleanupStoreOnKeepAlive) { + this.resumeToken = resumeToken; + this.session = resumeToken.toString(CharsetUtil.UTF_8); + this.connectionFactory = + connectionFactory + .doOnDiscard( + DuplexConnection.class, + c -> { + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("resumption_server=[Session Expired]"); + c.sendErrorAndClose(connectionErrorException); + c.receive().subscribe(); + }) + .flatMap( + dc -> { + final long impliedPosition = resumableFramesStore.frameImpliedPosition(); + final long position = resumableFramesStore.framePosition(); + dc.sendFrame( + 0, + ResumeFrameCodec.encode( + dc.alloc(), + resumeToken.retain(), + // server uses this to release its cache + impliedPosition, // observed on the client side + // server uses this to check whether there is no mismatch + position // sent from the client sent + )); + + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. ResumeFrame[impliedPosition[{}], position[{}]] has been sent.", + session, + impliedPosition, + position); + } + + return connectionTransformer.apply(dc); + }) + .doOnDiscard(Tuple2.class, this::tryReestablishSession); + this.resumableFramesStore = resumableFramesStore; + this.allocator = resumableDuplexConnection.alloc(); + this.resumeSessionDuration = resumeSessionDuration; + this.retry = retry; + this.cleanupStoreOnKeepAlive = cleanupStoreOnKeepAlive; + this.resumableConnection = resumableDuplexConnection; + + resumableDuplexConnection.onClose().doFinally(__ -> dispose()).subscribe(); + + this.reconnectDisposable = + resumableDuplexConnection.onActiveConnectionClosed().subscribe(this::reconnect); + } + + void reconnect(int index) { + if (this.s == Operators.cancelledSubscription()) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Connection[{}] is lost. Reconnecting rejected since session is closed", + session, + index); + } + return; + } + + keepAliveSupport.stop(); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Connection[{}] is lost. Reconnecting to resume...", + session, + index); + } + connectionFactory + .doOnNext(this::tryReestablishSession) + .retryWhen(retry) + .timeout(resumeSessionDuration) + .subscribe(this); + } + + @Override + public long impliedPosition() { + return resumableFramesStore.frameImpliedPosition(); + } + + @Override + public void onImpliedPosition(long remoteImpliedPos) { + if (cleanupStoreOnKeepAlive) { + try { + resumableFramesStore.releaseFrames(remoteImpliedPos); + } catch (Throwable e) { + resumableConnection.sendErrorAndClose(new ConnectionErrorException(e.getMessage(), e)); + } + } + } + + @Override + public void dispose() { + if (logger.isDebugEnabled()) { + logger.debug("Side[client]|Session[{}]. Disposing", session); + } + + boolean result = Operators.terminate(S, this); + + if (logger.isDebugEnabled()) { + logger.debug("Side[client]|Session[{}]. Sessions[isDisposed={}]", session, result); + } + + reconnectDisposable.dispose(); + resumableConnection.dispose(); + // frame store is disposed by resumable connection + // resumableFramesStore.dispose(); + + if (resumeToken.refCnt() > 0) { + resumeToken.release(); + } + } + + @Override + public boolean isDisposed() { + return resumableConnection.isDisposed(); + } + + void tryReestablishSession(Tuple2 tuple2) { + if (logger.isDebugEnabled()) { + logger.debug("Active subscription is canceled {}", s == Operators.cancelledSubscription()); + } + ByteBuf shouldBeResumeOKFrame = tuple2.getT1(); + DuplexConnection nextDuplexConnection = tuple2.getT2(); + + final int streamId = FrameHeaderCodec.streamId(shouldBeResumeOKFrame); + if (streamId != 0) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Illegal first frame received. RESUME_OK frame must be received before any others. Terminating received connection", + session); + } + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("RESUME_OK frame must be received before any others"); + resumableConnection.dispose(nextDuplexConnection, connectionErrorException); + nextDuplexConnection.sendErrorAndClose(connectionErrorException); + nextDuplexConnection.receive().subscribe(); + + throw connectionErrorException; // throw to retry connection again + } + + final FrameType frameType = FrameHeaderCodec.nativeFrameType(shouldBeResumeOKFrame); + if (frameType == FrameType.RESUME_OK) { + // how many frames the server has received from the client + // so the client can release cached frames by this point + long remoteImpliedPos = ResumeOkFrameCodec.lastReceivedClientPos(shouldBeResumeOKFrame); + // what was the last notification from the server about number of frames being + // observed + final long position = resumableFramesStore.framePosition(); + final long impliedPosition = resumableFramesStore.frameImpliedPosition(); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. ResumeOK FRAME received. ServerResumeState[remoteImpliedPosition[{}]]. ClientResumeState[impliedPosition[{}], position[{}]]", + session, + remoteImpliedPos, + impliedPosition, + position); + } + if (position <= remoteImpliedPos) { + try { + if (position != remoteImpliedPos) { + resumableFramesStore.releaseFrames(remoteImpliedPos); + } + } catch (IllegalStateException e) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Exception occurred while releasing frames in the frameStore", + session, + e); + } + final ConnectionErrorException t = new ConnectionErrorException(e.getMessage(), e); + + resumableConnection.dispose(nextDuplexConnection, t); + + nextDuplexConnection.sendErrorAndClose(t); + nextDuplexConnection.receive().subscribe(); + + return; + } + + if (!tryCancelSessionTimeout()) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Session has already been expired. Terminating received connection", + session); + } + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("resumption_server=[Session Expired]"); + nextDuplexConnection.sendErrorAndClose(connectionErrorException); + nextDuplexConnection.receive().subscribe(); + return; + } + + keepAliveSupport.start(); + + if (logger.isDebugEnabled()) { + logger.debug("Side[client]|Session[{}]. Session has been resumed successfully", session); + } + + if (!resumableConnection.connect(nextDuplexConnection)) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Session has already been expired. Terminating received connection", + session); + } + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("resumption_server_pos=[Session Expired]"); + nextDuplexConnection.sendErrorAndClose(connectionErrorException); + nextDuplexConnection.receive().subscribe(); + // no need to do anything since connection resumable connection is liklly to + // be disposed + } + } else { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Mismatching remote and local state. Expected RemoteImpliedPosition[{}] to be greater or equal to the LocalPosition[{}]. Terminating received connection", + session, + remoteImpliedPos, + position); + } + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("resumption_server_pos=[" + remoteImpliedPos + "]"); + + resumableConnection.dispose(nextDuplexConnection, connectionErrorException); + + nextDuplexConnection.sendErrorAndClose(connectionErrorException); + nextDuplexConnection.receive().subscribe(); + } + } else if (frameType == FrameType.ERROR) { + final RuntimeException exception = Exceptions.from(0, shouldBeResumeOKFrame); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Received error frame. Terminating received connection", + session, + exception); + } + if (exception instanceof RejectedResumeException) { + resumableConnection.dispose(nextDuplexConnection, exception); + nextDuplexConnection.dispose(); + nextDuplexConnection.receive().subscribe(); + return; + } + + nextDuplexConnection.dispose(); + nextDuplexConnection.receive().subscribe(); + throw exception; // assume retryable exception + } else { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Illegal first frame received. RESUME_OK frame must be received before any others. Terminating received connection", + session); + } + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("RESUME_OK frame must be received before any others"); + + resumableConnection.dispose(nextDuplexConnection, connectionErrorException); + + nextDuplexConnection.sendErrorAndClose(connectionErrorException); + nextDuplexConnection.receive().subscribe(); + + // no need to do anything since remote server rejected our connection completely + } + } + + boolean tryCancelSessionTimeout() { + for (; ; ) { + final Subscription subscription = this.s; + + if (subscription == Operators.cancelledSubscription()) { + return false; + } + + if (S.compareAndSet(this, subscription, null)) { + subscription.cancel(); + return true; + } + } + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } + } + + @Override + public void onNext(Tuple2 objects) {} + + @Override + public void onError(Throwable t) { + if (!Operators.terminate(S, this)) { + Operators.onErrorDropped(t, currentContext()); + } + + resumableConnection.dispose(); + } + + @Override + public void onComplete() {} + + public void setKeepAliveSupport(KeepAliveSupport keepAliveSupport) { + this.keepAliveSupport = keepAliveSupport; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ClientResume.java b/rsocket-core/src/main/java/io/rsocket/resume/ClientResume.java new file mode 100644 index 000000000..415a77f92 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ClientResume.java @@ -0,0 +1,38 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import java.time.Duration; + +public class ClientResume { + private final Duration sessionDuration; + private final ByteBuf resumeToken; + + public ClientResume(Duration sessionDuration, ByteBuf resumeToken) { + this.sessionDuration = sessionDuration; + this.resumeToken = resumeToken; + } + + public Duration sessionDuration() { + return sessionDuration; + } + + public ByteBuf resumeToken() { + return resumeToken; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/InMemoryResumableFramesStore.java b/rsocket-core/src/main/java/io/rsocket/resume/InMemoryResumableFramesStore.java new file mode 100644 index 000000000..e23bc154b --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/InMemoryResumableFramesStore.java @@ -0,0 +1,854 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import static io.rsocket.resume.ResumableDuplexConnection.isResumableFrame; + +import io.netty.buffer.ByteBuf; +import io.netty.util.CharsetUtil; +import java.util.ArrayDeque; +import java.util.Queue; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; + +/** + * writes - n (where n is frequent, primary operation) reads - m (where m == KeepAliveFrequency) + * skip - k -> 0 (where k is the rare operation which happens after disconnection + */ +public class InMemoryResumableFramesStore extends Flux + implements ResumableFramesStore, Subscription { + + private FramesSubscriber framesSubscriber; + private static final Logger logger = LoggerFactory.getLogger(InMemoryResumableFramesStore.class); + + final Sinks.Empty disposed = Sinks.empty(); + final Queue cachedFrames; + final String side; + final String session; + final int cacheLimit; + + volatile long impliedPosition; + static final AtomicLongFieldUpdater IMPLIED_POSITION = + AtomicLongFieldUpdater.newUpdater(InMemoryResumableFramesStore.class, "impliedPosition"); + + volatile long firstAvailableFramePosition; + static final AtomicLongFieldUpdater FIRST_AVAILABLE_FRAME_POSITION = + AtomicLongFieldUpdater.newUpdater( + InMemoryResumableFramesStore.class, "firstAvailableFramePosition"); + + long remoteImpliedPosition; + + int cacheSize; + + Throwable terminal; + + CoreSubscriber actual; + CoreSubscriber pendingActual; + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(InMemoryResumableFramesStore.class, "state"); + + /** + * Flag which indicates that {@link InMemoryResumableFramesStore} is finalized and all related + * stores are cleaned + */ + static final long FINALIZED_FLAG = + 0b1000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that {@link InMemoryResumableFramesStore} is terminated via the {@link + * InMemoryResumableFramesStore#dispose()} method + */ + static final long DISPOSED_FLAG = + 0b0100_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that {@link InMemoryResumableFramesStore} is terminated via the {@link + * FramesSubscriber#onComplete()} or {@link FramesSubscriber#onError(Throwable)} ()} methods + */ + static final long TERMINATED_FLAG = + 0b0010_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** Flag which indicates that {@link InMemoryResumableFramesStore} has active frames consumer */ + static final long CONNECTED_FLAG = + 0b0001_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that {@link InMemoryResumableFramesStore} has no active frames consumer + * but there is a one pending + */ + static final long PENDING_CONNECTION_FLAG = + 0b0000_1000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that there are some received implied position changes from the remote + * party + */ + static final long REMOTE_IMPLIED_POSITION_CHANGED_FLAG = + 0b0000_0100_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that there are some frames stored in the {@link + * io.rsocket.internal.UnboundedProcessor} which has to be cached and sent to the remote party + */ + static final long HAS_FRAME_FLAG = + 0b0000_0010_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that {@link InMemoryResumableFramesStore#drain(long)} has an actor which + * is currently progressing on the work. This flag should work as a guard to enter|exist into|from + * the {@link InMemoryResumableFramesStore#drain(long)} method. + */ + static final long MAX_WORK_IN_PROGRESS = + 0b0000_0000_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111L; + + public InMemoryResumableFramesStore(String side, ByteBuf session, int cacheSizeBytes) { + this.side = side; + this.session = session.toString(CharsetUtil.UTF_8); + this.cacheLimit = cacheSizeBytes; + this.cachedFrames = new ArrayDeque<>(); + } + + public Mono saveFrames(Flux frames) { + return frames + .transform( + Operators.lift( + (__, actual) -> this.framesSubscriber = new FramesSubscriber(actual, this))) + .then(); + } + + @Override + public void releaseFrames(long remoteImpliedPos) { + long lastReceivedRemoteImpliedPosition = this.remoteImpliedPosition; + if (lastReceivedRemoteImpliedPosition > remoteImpliedPos) { + throw new IllegalStateException( + "Given Remote Implied Position is behind the last received Remote Implied Position"); + } + + this.remoteImpliedPosition = remoteImpliedPos; + + final long previousState = markRemoteImpliedPositionChanged(this); + if (isFinalized(previousState) || isWorkInProgress(previousState)) { + return; + } + + drain((previousState + 1) | REMOTE_IMPLIED_POSITION_CHANGED_FLAG); + } + + void drain(long expectedState) { + final Fuseable.QueueSubscription qs = this.framesSubscriber.qs; + final Queue cachedFrames = this.cachedFrames; + + for (; ; ) { + if (hasRemoteImpliedPositionChanged(expectedState)) { + expectedState = handlePendingRemoteImpliedPositionChanges(expectedState, cachedFrames); + } + + if (hasPendingConnection(expectedState)) { + expectedState = handlePendingConnection(expectedState, cachedFrames); + } + + if (isConnected(expectedState)) { + if (isTerminated(expectedState)) { + handleTerminated(qs, this.terminal); + } else if (isDisposed()) { + handleDisposed(); + } else if (hasFrames(expectedState)) { + handlePendingFrames(qs); + } + } + + if (isDisposed(expectedState) || isTerminated(expectedState)) { + clearAndFinalize(this); + return; + } + + expectedState = markWorkDone(this, expectedState); + if (isFinalized(expectedState)) { + return; + } + + if (!isWorkInProgress(expectedState)) { + return; + } + } + } + + long handlePendingRemoteImpliedPositionChanges(long expectedState, Queue cachedFrames) { + final long remoteImpliedPosition = this.remoteImpliedPosition; + final long firstAvailableFramePosition = this.firstAvailableFramePosition; + final long toDropFromCache = Math.max(0, remoteImpliedPosition - firstAvailableFramePosition); + + if (toDropFromCache > 0) { + final int droppedFromCache = dropFramesFromCache(toDropFromCache, cachedFrames); + + if (toDropFromCache > droppedFromCache) { + this.terminal = + new IllegalStateException( + String.format( + "Local and remote state disagreement: " + + "need to remove additional %d bytes, but cache is empty", + toDropFromCache)); + expectedState = markTerminated(this) | TERMINATED_FLAG; + } + + if (toDropFromCache < droppedFromCache) { + this.terminal = + new IllegalStateException( + "Local and remote state disagreement: local and remote frame sizes are not equal"); + expectedState = markTerminated(this) | TERMINATED_FLAG; + } + + FIRST_AVAILABLE_FRAME_POSITION.lazySet(this, firstAvailableFramePosition + droppedFromCache); + if (this.cacheLimit != Integer.MAX_VALUE) { + this.cacheSize -= droppedFromCache; + + if (logger.isDebugEnabled()) { + logger.debug( + "Side[{}]|Session[{}]. Removed frames from cache to position[{}]. CacheSize[{}]", + this.side, + this.session, + this.remoteImpliedPosition, + this.cacheSize); + } + } + } + + return expectedState; + } + + void handlePendingFrames(Fuseable.QueueSubscription qs) { + for (; ; ) { + final ByteBuf frame = qs.poll(); + final boolean empty = frame == null; + + if (empty) { + break; + } + + handleFrame(frame); + + if (!isConnected(this.state)) { + break; + } + } + } + + long handlePendingConnection(long expectedState, Queue cachedFrames) { + CoreSubscriber lastActual = null; + for (; ; ) { + final CoreSubscriber nextActual = this.pendingActual; + + if (nextActual != lastActual) { + for (final ByteBuf frame : cachedFrames) { + nextActual.onNext(frame.retainedSlice()); + } + } + + expectedState = markConnected(this, expectedState); + if (isConnected(expectedState)) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[{}]|Session[{}]. Connected at Position[{}] and ImpliedPosition[{}]", + side, + session, + firstAvailableFramePosition, + impliedPosition); + } + + this.actual = nextActual; + break; + } + + if (!hasPendingConnection(expectedState)) { + break; + } + + lastActual = nextActual; + } + return expectedState; + } + + static int dropFramesFromCache(long toRemoveBytes, Queue cache) { + int removedBytes = 0; + while (toRemoveBytes > removedBytes && cache.size() > 0) { + final ByteBuf cachedFrame = cache.poll(); + final int frameSize = cachedFrame.readableBytes(); + + cachedFrame.release(); + + removedBytes += frameSize; + } + + return removedBytes; + } + + @Override + public Flux resumeStream() { + return this; + } + + @Override + public long framePosition() { + return this.firstAvailableFramePosition; + } + + @Override + public long frameImpliedPosition() { + return this.impliedPosition & Long.MAX_VALUE; + } + + @Override + public boolean resumableFrameReceived(ByteBuf frame) { + final int frameSize = frame.readableBytes(); + for (; ; ) { + final long impliedPosition = this.impliedPosition; + + if (impliedPosition < 0) { + return false; + } + + if (IMPLIED_POSITION.compareAndSet(this, impliedPosition, impliedPosition + frameSize)) { + return true; + } + } + } + + void pauseImplied() { + for (; ; ) { + final long impliedPosition = this.impliedPosition; + + if (IMPLIED_POSITION.compareAndSet(this, impliedPosition, impliedPosition | Long.MIN_VALUE)) { + logger.debug( + "Side[{}]|Session[{}]. Paused at position[{}]", side, session, impliedPosition); + return; + } + } + } + + void resumeImplied() { + for (; ; ) { + final long impliedPosition = this.impliedPosition; + + final long restoredImpliedPosition = impliedPosition & Long.MAX_VALUE; + if (IMPLIED_POSITION.compareAndSet(this, impliedPosition, restoredImpliedPosition)) { + logger.debug( + "Side[{}]|Session[{}]. Resumed at position[{}]", + side, + session, + restoredImpliedPosition); + return; + } + } + } + + @Override + public Mono onClose() { + return disposed.asMono(); + } + + @Override + public void dispose() { + final long previousState = markDisposed(this); + if (isFinalized(previousState) + || isDisposed(previousState) + || isWorkInProgress(previousState)) { + return; + } + + drain((previousState + 1) | DISPOSED_FLAG); + } + + void clearCache() { + final Queue frames = this.cachedFrames; + this.cacheSize = 0; + + ByteBuf frame; + while ((frame = frames.poll()) != null) { + frame.release(); + } + } + + @Override + public boolean isDisposed() { + return isDisposed(this.state); + } + + void handleFrame(ByteBuf frame) { + final boolean isResumable = isResumableFrame(frame); + if (isResumable) { + handleResumableFrame(frame); + return; + } + + handleConnectionFrame(frame); + } + + void handleTerminated(Fuseable.QueueSubscription qs, @Nullable Throwable t) { + for (; ; ) { + final ByteBuf frame = qs.poll(); + final boolean empty = frame == null; + + if (empty) { + break; + } + + handleFrame(frame); + } + if (t != null) { + this.actual.onError(t); + } else { + this.actual.onComplete(); + } + } + + void handleDisposed() { + this.actual.onError(new CancellationException("Disposed")); + } + + void handleConnectionFrame(ByteBuf frame) { + this.actual.onNext(frame); + } + + void handleResumableFrame(ByteBuf frame) { + final Queue frames = this.cachedFrames; + final int incomingFrameSize = frame.readableBytes(); + final int cacheLimit = this.cacheLimit; + + final boolean canBeStore; + int cacheSize = this.cacheSize; + if (cacheLimit != Integer.MAX_VALUE) { + final long availableSize = cacheLimit - cacheSize; + + if (availableSize < incomingFrameSize) { + final long firstAvailableFramePosition = this.firstAvailableFramePosition; + final long toRemoveBytes = incomingFrameSize - availableSize; + final int removedBytes = dropFramesFromCache(toRemoveBytes, frames); + + cacheSize = cacheSize - removedBytes; + canBeStore = removedBytes >= toRemoveBytes; + + if (canBeStore) { + FIRST_AVAILABLE_FRAME_POSITION.lazySet(this, firstAvailableFramePosition + removedBytes); + } else { + this.cacheSize = cacheSize; + FIRST_AVAILABLE_FRAME_POSITION.lazySet( + this, firstAvailableFramePosition + removedBytes + incomingFrameSize); + } + } else { + canBeStore = true; + } + } else { + canBeStore = true; + } + + if (canBeStore) { + frames.offer(frame); + + if (cacheLimit != Integer.MAX_VALUE) { + this.cacheSize = cacheSize + incomingFrameSize; + } + } + + this.actual.onNext(canBeStore ? frame.retainedSlice() : frame); + } + + @Override + public void request(long n) {} + + @Override + public void cancel() { + pauseImplied(); + markDisconnected(this); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[{}]|Session[{}]. Disconnected at Position[{}] and ImpliedPosition[{}]", + side, + session, + firstAvailableFramePosition, + frameImpliedPosition()); + } + } + + @Override + public void subscribe(CoreSubscriber actual) { + resumeImplied(); + actual.onSubscribe(this); + this.pendingActual = actual; + + final long previousState = markPendingConnection(this); + if (isDisposed(previousState)) { + actual.onError(new CancellationException("Disposed")); + return; + } + + if (isTerminated(previousState)) { + actual.onError(new CancellationException("Disposed")); + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + drain((previousState + 1) | PENDING_CONNECTION_FLAG); + } + + static class FramesSubscriber + implements CoreSubscriber, Fuseable.QueueSubscription { + + final CoreSubscriber actual; + final InMemoryResumableFramesStore parent; + + Fuseable.QueueSubscription qs; + + boolean done; + + FramesSubscriber(CoreSubscriber actual, InMemoryResumableFramesStore parent) { + this.actual = actual; + this.parent = parent; + } + + @Override + @SuppressWarnings("unchecked") + public void onSubscribe(Subscription s) { + if (Operators.validate(this.qs, s)) { + final Fuseable.QueueSubscription qs = (Fuseable.QueueSubscription) s; + this.qs = qs; + + final int m = qs.requestFusion(Fuseable.ANY); + + if (m != Fuseable.ASYNC) { + s.cancel(); + this.actual.onSubscribe(this); + this.actual.onError(new IllegalStateException("Source has to be ASYNC fuseable")); + return; + } + + this.actual.onSubscribe(this); + } + } + + @Override + public void onNext(ByteBuf byteBuf) { + final InMemoryResumableFramesStore parent = this.parent; + long previousState = InMemoryResumableFramesStore.markFrameAdded(parent); + + if (isFinalized(previousState)) { + this.qs.clear(); + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + if (isConnected(previousState) || hasPendingConnection(previousState)) { + parent.drain((previousState + 1) | HAS_FRAME_FLAG); + } + } + + @Override + public void onError(Throwable t) { + if (this.done) { + Operators.onErrorDropped(t, this.actual.currentContext()); + return; + } + + final InMemoryResumableFramesStore parent = this.parent; + + parent.terminal = t; + this.done = true; + + final long previousState = InMemoryResumableFramesStore.markTerminated(parent); + if (isFinalized(previousState)) { + Operators.onErrorDropped(t, this.actual.currentContext()); + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + parent.drain((previousState + 1) | TERMINATED_FLAG); + } + + @Override + public void onComplete() { + if (this.done) { + return; + } + + final InMemoryResumableFramesStore parent = this.parent; + + this.done = true; + + final long previousState = InMemoryResumableFramesStore.markTerminated(parent); + if (isFinalized(previousState)) { + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + parent.drain((previousState + 1) | TERMINATED_FLAG); + } + + @Override + public void cancel() { + if (this.done) { + return; + } + + this.done = true; + + final long previousState = InMemoryResumableFramesStore.markTerminated(parent); + if (isFinalized(previousState)) { + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + parent.drain(previousState | TERMINATED_FLAG); + } + + @Override + public void request(long n) {} + + @Override + public int requestFusion(int requestedMode) { + return Fuseable.NONE; + } + + @Override + public Void poll() { + return null; + } + + @Override + public int size() { + return 0; + } + + @Override + public boolean isEmpty() { + return false; + } + + @Override + public void clear() {} + } + + static long markFrameAdded(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state)) { + return state; + } + + long nextState = state; + if (isConnected(state) || hasPendingConnection(state) || isWorkInProgress(state)) { + nextState = + (state & MAX_WORK_IN_PROGRESS) == MAX_WORK_IN_PROGRESS ? nextState : nextState + 1; + } + + if (STATE.compareAndSet(store, state, nextState | HAS_FRAME_FLAG)) { + return state; + } + } + } + + static long markPendingConnection(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state) || isDisposed(state) || isTerminated(state)) { + return state; + } + + if (isConnected(state)) { + return state; + } + + final long nextState = + (state & MAX_WORK_IN_PROGRESS) == MAX_WORK_IN_PROGRESS ? state : state + 1; + if (STATE.compareAndSet(store, state, nextState | PENDING_CONNECTION_FLAG)) { + return state; + } + } + } + + static long markRemoteImpliedPositionChanged(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state)) { + return state; + } + + final long nextState = + (state & MAX_WORK_IN_PROGRESS) == MAX_WORK_IN_PROGRESS ? state : (state + 1); + if (STATE.compareAndSet(store, state, nextState | REMOTE_IMPLIED_POSITION_CHANGED_FLAG)) { + return state; + } + } + } + + static long markDisconnected(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state)) { + return state; + } + + if (STATE.compareAndSet(store, state, state & ~CONNECTED_FLAG & ~PENDING_CONNECTION_FLAG)) { + return state; + } + } + } + + static long markWorkDone(InMemoryResumableFramesStore store, long expectedState) { + for (; ; ) { + final long state = store.state; + + if (expectedState != state) { + return state; + } + + if (isFinalized(state)) { + return state; + } + + final long nextState = state & ~MAX_WORK_IN_PROGRESS & ~REMOTE_IMPLIED_POSITION_CHANGED_FLAG; + if (STATE.compareAndSet(store, state, nextState)) { + return nextState; + } + } + } + + static long markConnected(InMemoryResumableFramesStore store, long expectedState) { + for (; ; ) { + final long state = store.state; + + if (state != expectedState) { + return state; + } + + if (isFinalized(state)) { + return state; + } + + final long nextState = state ^ PENDING_CONNECTION_FLAG | CONNECTED_FLAG; + if (STATE.compareAndSet(store, state, nextState)) { + return nextState; + } + } + } + + static long markTerminated(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state)) { + return state; + } + + final long nextState = + (state & MAX_WORK_IN_PROGRESS) == MAX_WORK_IN_PROGRESS ? state : (state + 1); + if (STATE.compareAndSet(store, state, nextState | TERMINATED_FLAG)) { + return state; + } + } + } + + static long markDisposed(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state)) { + return state; + } + + final long nextState = + (state & MAX_WORK_IN_PROGRESS) == MAX_WORK_IN_PROGRESS ? state : (state + 1); + if (STATE.compareAndSet(store, state, nextState | DISPOSED_FLAG)) { + return state; + } + } + } + + static void clearAndFinalize(InMemoryResumableFramesStore store) { + final Fuseable.QueueSubscription qs = store.framesSubscriber.qs; + for (; ; ) { + final long state = store.state; + + qs.clear(); + store.clearCache(); + + if (isFinalized(state)) { + return; + } + + if (STATE.compareAndSet(store, state, state | FINALIZED_FLAG & ~MAX_WORK_IN_PROGRESS)) { + store.disposed.tryEmitEmpty(); + store.framesSubscriber.onComplete(); + return; + } + } + } + + static boolean isConnected(long state) { + return (state & CONNECTED_FLAG) == CONNECTED_FLAG; + } + + static boolean hasRemoteImpliedPositionChanged(long state) { + return (state & REMOTE_IMPLIED_POSITION_CHANGED_FLAG) == REMOTE_IMPLIED_POSITION_CHANGED_FLAG; + } + + static boolean hasPendingConnection(long state) { + return (state & PENDING_CONNECTION_FLAG) == PENDING_CONNECTION_FLAG; + } + + static boolean hasFrames(long state) { + return (state & HAS_FRAME_FLAG) == HAS_FRAME_FLAG; + } + + static boolean isTerminated(long state) { + return (state & TERMINATED_FLAG) == TERMINATED_FLAG; + } + + static boolean isDisposed(long state) { + return (state & DISPOSED_FLAG) == DISPOSED_FLAG; + } + + static boolean isFinalized(long state) { + return (state & FINALIZED_FLAG) == FINALIZED_FLAG; + } + + static boolean isWorkInProgress(long state) { + return (state & MAX_WORK_IN_PROGRESS) > 0; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/RSocketSession.java b/rsocket-core/src/main/java/io/rsocket/resume/RSocketSession.java new file mode 100644 index 000000000..6dd3d5f4d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/RSocketSession.java @@ -0,0 +1,25 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.rsocket.keepalive.KeepAliveSupport; +import reactor.core.Disposable; + +public interface RSocketSession extends Disposable { + + void setKeepAliveSupport(KeepAliveSupport keepAliveSupport); +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java new file mode 100644 index 000000000..c8811b9b3 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java @@ -0,0 +1,447 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.CharsetUtil; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.internal.UnboundedProcessor; +import java.net.SocketAddress; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.Disposable; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; + +public class ResumableDuplexConnection extends Flux + implements DuplexConnection, Subscription { + + static final Logger logger = LoggerFactory.getLogger(ResumableDuplexConnection.class); + + final String side; + final String session; + final ResumableFramesStore resumableFramesStore; + + final UnboundedProcessor savableFramesSender; + final Sinks.Empty onQueueClose; + final Sinks.Empty onLastConnectionClose; + final SocketAddress remoteAddress; + final Sinks.Many onConnectionClosedSink; + + CoreSubscriber receiveSubscriber; + FrameReceivingSubscriber activeReceivingSubscriber; + + volatile int state; + static final AtomicIntegerFieldUpdater STATE = + AtomicIntegerFieldUpdater.newUpdater(ResumableDuplexConnection.class, "state"); + + volatile DuplexConnection activeConnection; + static final AtomicReferenceFieldUpdater + ACTIVE_CONNECTION = + AtomicReferenceFieldUpdater.newUpdater( + ResumableDuplexConnection.class, DuplexConnection.class, "activeConnection"); + + int connectionIndex = 0; + + public ResumableDuplexConnection( + String side, + ByteBuf session, + DuplexConnection initialConnection, + ResumableFramesStore resumableFramesStore) { + this.side = side; + this.session = session.toString(CharsetUtil.UTF_8); + this.onConnectionClosedSink = Sinks.unsafe().many().unicast().onBackpressureBuffer(); + this.resumableFramesStore = resumableFramesStore; + this.onQueueClose = Sinks.unsafe().empty(); + this.onLastConnectionClose = Sinks.unsafe().empty(); + this.savableFramesSender = new UnboundedProcessor(onQueueClose::tryEmitEmpty); + this.remoteAddress = initialConnection.remoteAddress(); + + resumableFramesStore.saveFrames(savableFramesSender).subscribe(); + + ACTIVE_CONNECTION.lazySet(this, initialConnection); + } + + public boolean connect(DuplexConnection nextConnection) { + final DuplexConnection activeConnection = this.activeConnection; + if (activeConnection != DisposedConnection.INSTANCE + && ACTIVE_CONNECTION.compareAndSet(this, activeConnection, nextConnection)) { + + if (!activeConnection.isDisposed()) { + activeConnection.sendErrorAndClose( + new ConnectionErrorException("Connection unexpectedly replaced")); + } + + initConnection(nextConnection); + + return true; + } else { + return false; + } + } + + void initConnection(DuplexConnection nextConnection) { + final int nextConnectionIndex = this.connectionIndex + 1; + final FrameReceivingSubscriber frameReceivingSubscriber = + new FrameReceivingSubscriber(side, resumableFramesStore, receiveSubscriber); + + this.connectionIndex = nextConnectionIndex; + this.activeReceivingSubscriber = frameReceivingSubscriber; + + if (logger.isDebugEnabled()) { + logger.debug( + "Side[{}]|Session[{}]|DuplexConnection[{}]. Connecting", side, session, connectionIndex); + } + + final Disposable resumeStreamSubscription = + resumableFramesStore + .resumeStream() + .subscribe( + f -> nextConnection.sendFrame(FrameHeaderCodec.streamId(f), f), + t -> { + dispose(nextConnection, t); + nextConnection.sendErrorAndClose(new ConnectionErrorException(t.getMessage(), t)); + }, + () -> { + final ConnectionErrorException e = + new ConnectionErrorException("Connection Closed Unexpectedly"); + dispose(nextConnection, e); + nextConnection.sendErrorAndClose(e); + }); + nextConnection.receive().subscribe(frameReceivingSubscriber); + nextConnection + .onClose() + .doFinally( + __ -> { + frameReceivingSubscriber.dispose(); + resumeStreamSubscription.dispose(); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[{}]|Session[{}]|DuplexConnection[{}]. Disconnected", + side, + session, + connectionIndex); + } + Sinks.EmitResult result = onConnectionClosedSink.tryEmitNext(nextConnectionIndex); + if (!result.equals(Sinks.EmitResult.OK)) { + logger.error( + "Side[{}]|Session[{}]|DuplexConnection[{}]. Failed to notify session of closed connection: {}", + side, + session, + connectionIndex, + result); + } + }) + .subscribe(); + } + + public void disconnect() { + final DuplexConnection activeConnection = this.activeConnection; + if (activeConnection != DisposedConnection.INSTANCE && !activeConnection.isDisposed()) { + activeConnection.dispose(); + } + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + if (streamId == 0) { + savableFramesSender.tryEmitPrioritized(frame); + } else { + savableFramesSender.tryEmitNormal(frame); + } + } + + /** + * Publisher for a sequence of integers starting at 1, with each next number emitted when the + * currently active connection is closed and should be resumed. The Publisher never emits an error + * and completes when the connection is disposed and not resumed. + */ + Flux onActiveConnectionClosed() { + return onConnectionClosedSink.asFlux(); + } + + @Override + public void sendErrorAndClose(RSocketErrorException rSocketErrorException) { + final DuplexConnection activeConnection = + ACTIVE_CONNECTION.getAndSet(this, DisposedConnection.INSTANCE); + if (activeConnection == DisposedConnection.INSTANCE) { + return; + } + + savableFramesSender.tryEmitFinal( + ErrorFrameCodec.encode(activeConnection.alloc(), 0, rSocketErrorException)); + + activeConnection + .onClose() + .subscribe( + null, + t -> { + onConnectionClosedSink.tryEmitComplete(); + onLastConnectionClose.tryEmitEmpty(); + }, + () -> { + onConnectionClosedSink.tryEmitComplete(); + + final Throwable cause = rSocketErrorException.getCause(); + if (cause == null) { + onLastConnectionClose.tryEmitEmpty(); + } else { + onLastConnectionClose.tryEmitError(cause); + } + }); + } + + @Override + public Flux receive() { + return this; + } + + @Override + public ByteBufAllocator alloc() { + return activeConnection.alloc(); + } + + @Override + public Mono onClose() { + return Mono.whenDelayError( + onQueueClose.asMono(), resumableFramesStore.onClose(), onLastConnectionClose.asMono()); + } + + @Override + public void dispose() { + final DuplexConnection activeConnection = + ACTIVE_CONNECTION.getAndSet(this, DisposedConnection.INSTANCE); + if (activeConnection == DisposedConnection.INSTANCE) { + return; + } + savableFramesSender.onComplete(); + activeConnection + .onClose() + .subscribe( + null, + t -> { + onConnectionClosedSink.tryEmitComplete(); + onLastConnectionClose.tryEmitEmpty(); + }, + () -> { + onConnectionClosedSink.tryEmitComplete(); + onLastConnectionClose.tryEmitEmpty(); + }); + } + + void dispose(DuplexConnection nextConnection, @Nullable Throwable e) { + final DuplexConnection activeConnection = + ACTIVE_CONNECTION.getAndSet(this, DisposedConnection.INSTANCE); + if (activeConnection == DisposedConnection.INSTANCE) { + return; + } + savableFramesSender.onComplete(); + nextConnection + .onClose() + .subscribe( + null, + t -> { + if (e != null) { + onLastConnectionClose.tryEmitError(e); + } else { + onLastConnectionClose.tryEmitEmpty(); + } + onConnectionClosedSink.tryEmitComplete(); + }, + () -> { + if (e != null) { + onLastConnectionClose.tryEmitError(e); + } else { + onLastConnectionClose.tryEmitEmpty(); + } + onConnectionClosedSink.tryEmitComplete(); + }); + } + + @Override + @SuppressWarnings("ConstantConditions") + public boolean isDisposed() { + return onQueueClose.scan(Scannable.Attr.TERMINATED) + || onQueueClose.scan(Scannable.Attr.CANCELLED); + } + + @Override + public SocketAddress remoteAddress() { + return remoteAddress; + } + + @Override + public void request(long n) { + if (state == 1 && STATE.compareAndSet(this, 1, 2)) { + // happens for the very first time with the initial connection + initConnection(this.activeConnection); + } + } + + @Override + public void cancel() { + dispose(); + } + + @Override + public void subscribe(CoreSubscriber receiverSubscriber) { + if (state == 0 && STATE.compareAndSet(this, 0, 1)) { + receiveSubscriber = receiverSubscriber; + receiverSubscriber.onSubscribe(this); + } + } + + static boolean isResumableFrame(ByteBuf frame) { + return FrameHeaderCodec.streamId(frame) != 0; + } + + @Override + public String toString() { + return "ResumableDuplexConnection{" + + "side='" + + side + + '\'' + + ", session='" + + session + + '\'' + + ", remoteAddress=" + + remoteAddress + + ", state=" + + state + + ", activeConnection=" + + activeConnection + + ", connectionIndex=" + + connectionIndex + + '}'; + } + + private static final class DisposedConnection implements DuplexConnection { + + static final DisposedConnection INSTANCE = new DisposedConnection(); + + private DisposedConnection() {} + + @Override + public void dispose() {} + + @Override + public Mono onClose() { + return Mono.never(); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) {} + + @Override + public Flux receive() { + return Flux.never(); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) {} + + @Override + public ByteBufAllocator alloc() { + return ByteBufAllocator.DEFAULT; + } + + @Override + @SuppressWarnings("ConstantConditions") + public SocketAddress remoteAddress() { + return null; + } + } + + private static final class FrameReceivingSubscriber + implements CoreSubscriber, Disposable { + + final ResumableFramesStore resumableFramesStore; + final CoreSubscriber actual; + final String tag; + + volatile Subscription s; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater( + FrameReceivingSubscriber.class, Subscription.class, "s"); + + boolean cancelled; + + private FrameReceivingSubscriber( + String tag, ResumableFramesStore store, CoreSubscriber actual) { + this.tag = tag; + this.resumableFramesStore = store; + this.actual = actual; + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } + } + + @Override + public void onNext(ByteBuf frame) { + if (cancelled || s == Operators.cancelledSubscription()) { + return; + } + + if (isResumableFrame(frame)) { + if (resumableFramesStore.resumableFrameReceived(frame)) { + actual.onNext(frame); + } + return; + } + + actual.onNext(frame); + } + + @Override + public void onError(Throwable t) { + Operators.set(S, this, Operators.cancelledSubscription()); + } + + @Override + public void onComplete() { + Operators.set(S, this, Operators.cancelledSubscription()); + } + + @Override + public void dispose() { + cancelled = true; + Operators.terminate(S, this); + } + + @Override + public boolean isDisposed() { + return cancelled || s == Operators.cancelledSubscription(); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumableFramesStore.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumableFramesStore.java new file mode 100644 index 000000000..80d9a36dd --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ResumableFramesStore.java @@ -0,0 +1,57 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Closeable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** Store for resumable frames */ +public interface ResumableFramesStore extends Closeable { + + /** + * Save resumable frames for potential resumption + * + * @param frames {@link Flux} of resumable frames + * @return {@link Mono} which completes once all resume frames are written + */ + Mono saveFrames(Flux frames); + + /** Release frames from tail of the store up to remote implied position */ + void releaseFrames(long remoteImpliedPos); + + /** + * @return {@link Flux} of frames from store tail to head. It should terminate with error if + * frames are not continuous + */ + Flux resumeStream(); + + /** @return Local frame position as defined by RSocket protocol */ + long framePosition(); + + /** @return Implied frame position as defined by RSocket protocol */ + long frameImpliedPosition(); + + /** + * Received resumable frame as defined by RSocket protocol. Implementation must increment frame + * implied position + * + * @return {@code true} if information about the frame has been stored + */ + boolean resumableFrameReceived(ByteBuf frame); +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumeCache.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumeCache.java deleted file mode 100644 index 5550ce47c..000000000 --- a/rsocket-core/src/main/java/io/rsocket/resume/ResumeCache.java +++ /dev/null @@ -1,114 +0,0 @@ -package io.rsocket.resume; - -import io.rsocket.Frame; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import reactor.core.publisher.Flux; - -public class ResumeCache { - private final ResumePositionCounter strategy; - private final int maxBufferSize; - - private final LinkedHashMap frames = new LinkedHashMap<>(); - private int lastRemotePosition = 0; - private int currentPosition = 0; - private int bufferSize; - - public ResumeCache(ResumePositionCounter strategy, int maxBufferSize) { - this.strategy = strategy; - this.maxBufferSize = maxBufferSize; - } - - public void updateRemotePosition(int remotePosition) { - if (remotePosition > currentPosition) { - throw new IllegalStateException( - "Remote ahead of " + lastRemotePosition + " , expected " + remotePosition); - } - - if (remotePosition == lastRemotePosition) { - return; - } - - if (remotePosition < lastRemotePosition) { - throw new IllegalStateException( - "Remote position moved back from " + lastRemotePosition + " to " + remotePosition); - } - - lastRemotePosition = remotePosition; - - Iterator> positions = frames.entrySet().iterator(); - - while (positions.hasNext()) { - Map.Entry cachePosition = positions.next(); - - if (cachePosition.getKey() <= remotePosition) { - positions.remove(); - bufferSize -= strategy.cost(cachePosition.getValue()); - cachePosition.getValue().release(); - } - - // TODO check for a bad position - } - } - - public void sent(Frame frame) { - if (ResumeUtil.isTracked(frame)) { - frames.put(currentPosition, frame.copy()); - bufferSize += strategy.cost(frame); - - currentPosition += ResumeUtil.offset(frame); - - if (frames.size() > maxBufferSize) { - Frame f = frames.remove(first(frames)); - bufferSize -= strategy.cost(f); - } - } - } - - private int first(LinkedHashMap frames) { - return frames.keySet().iterator().next(); - } - - public Flux resend(int remotePosition) { - updateRemotePosition(remotePosition); - - if (remotePosition == currentPosition) { - return Flux.empty(); - } - - List resend = new ArrayList<>(); - - for (Map.Entry cachePosition : frames.entrySet()) { - if (remotePosition < cachePosition.getKey()) { - resend.add(cachePosition.getValue()); - } - - // TODO error handling - } - - return Flux.fromIterable(resend); - } - - public int getCurrentPosition() { - return currentPosition; - } - - public int getRemotePosition() { - return lastRemotePosition; - } - - public int getEarliestResendPosition() { - if (frames.isEmpty()) { - return currentPosition; - } else { - return first(frames); - } - } - - public int size() { - return bufferSize; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumePositionCounter.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumePositionCounter.java deleted file mode 100644 index 8d4caa251..000000000 --- a/rsocket-core/src/main/java/io/rsocket/resume/ResumePositionCounter.java +++ /dev/null @@ -1,19 +0,0 @@ -package io.rsocket.resume; - -import io.rsocket.Frame; - -/** - * Calculates the cost of a Frame when stored in the ResumeCache. Two obvious and provided - * strategies are simple frame counts and size in bytes. - */ -public interface ResumePositionCounter { - int cost(Frame f); - - static ResumePositionCounter size() { - return ResumeUtil::offset; - } - - static ResumePositionCounter frames() { - return f -> 1; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumeStateException.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumeStateException.java new file mode 100644 index 000000000..1fae24b07 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ResumeStateException.java @@ -0,0 +1,49 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +class ResumeStateException extends RuntimeException { + private static final long serialVersionUID = -5393753463377588732L; + private final long localPos; + private final long localImpliedPos; + private final long remotePos; + private final long remoteImpliedPos; + + public ResumeStateException( + long localPos, long localImpliedPos, long remotePos, long remoteImpliedPos) { + this.localPos = localPos; + this.localImpliedPos = localImpliedPos; + this.remotePos = remotePos; + this.remoteImpliedPos = remoteImpliedPos; + } + + public long getLocalPos() { + return localPos; + } + + public long getLocalImpliedPos() { + return localImpliedPos; + } + + public long getRemotePos() { + return remotePos; + } + + public long getRemoteImpliedPos() { + return remoteImpliedPos; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumeStateHolder.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumeStateHolder.java new file mode 100644 index 000000000..31687a24b --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ResumeStateHolder.java @@ -0,0 +1,24 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +public interface ResumeStateHolder { + + long impliedPosition(); + + void onImpliedPosition(long remoteImpliedPos); +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumeToken.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumeToken.java deleted file mode 100644 index e1764f0f3..000000000 --- a/rsocket-core/src/main/java/io/rsocket/resume/ResumeToken.java +++ /dev/null @@ -1,54 +0,0 @@ -package io.rsocket.resume; - -import io.netty.buffer.ByteBufUtil; -import java.nio.ByteBuffer; -import java.util.Arrays; -import java.util.UUID; - -public final class ResumeToken { - // TODO consider best format to store this - private final byte[] resumeToken; - - protected ResumeToken(byte[] resumeToken) { - this.resumeToken = resumeToken; - } - - public static ResumeToken bytes(byte[] token) { - return new ResumeToken(token); - } - - public static ResumeToken generate() { - return new ResumeToken(getBytesFromUUID(UUID.randomUUID())); - } - - static byte[] getBytesFromUUID(UUID uuid) { - ByteBuffer bb = ByteBuffer.wrap(new byte[16]); - bb.putLong(uuid.getMostSignificantBits()); - bb.putLong(uuid.getLeastSignificantBits()); - - return bb.array(); - } - - @Override - public int hashCode() { - return Arrays.hashCode(resumeToken); - } - - @Override - public boolean equals(Object obj) { - if (obj instanceof ResumeToken) { - return Arrays.equals(resumeToken, ((ResumeToken) obj).resumeToken); - } - - return false; - } - - @Override - public String toString() { - return ByteBufUtil.hexDump(resumeToken); - } - - public byte[] toByteArray() { - return resumeToken; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumeUtil.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumeUtil.java deleted file mode 100644 index 093120357..000000000 --- a/rsocket-core/src/main/java/io/rsocket/resume/ResumeUtil.java +++ /dev/null @@ -1,38 +0,0 @@ -package io.rsocket.resume; - -import io.rsocket.Frame; -import io.rsocket.FrameType; -import io.rsocket.frame.FrameHeaderFlyweight; - -public class ResumeUtil { - public static boolean isTracked(FrameType frameType) { - switch (frameType) { - case REQUEST_CHANNEL: - case REQUEST_STREAM: - case REQUEST_RESPONSE: - case FIRE_AND_FORGET: - // case METADATA_PUSH: - case REQUEST_N: - case CANCEL: - case ERROR: - case PAYLOAD: - return true; - default: - return false; - } - } - - public static boolean isTracked(Frame frame) { - return isTracked(frame.getType()); - } - - public static int offset(Frame frame) { - int length = frame.content().readableBytes(); - - if (length < FrameHeaderFlyweight.FRAME_HEADER_LENGTH) { - throw new IllegalStateException("invalid frame"); - } - - return length - FrameHeaderFlyweight.FRAME_LENGTH_SIZE; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java b/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java new file mode 100644 index 000000000..ad1b38375 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java @@ -0,0 +1,301 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.CharsetUtil; +import io.rsocket.DuplexConnection; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.exceptions.RejectedResumeException; +import io.rsocket.frame.ResumeFrameCodec; +import io.rsocket.frame.ResumeOkFrameCodec; +import io.rsocket.keepalive.KeepAliveSupport; +import java.time.Duration; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.concurrent.Queues; + +public class ServerRSocketSession + implements RSocketSession, ResumeStateHolder, CoreSubscriber { + private static final Logger logger = LoggerFactory.getLogger(ServerRSocketSession.class); + + final ResumableDuplexConnection resumableConnection; + final Duration resumeSessionDuration; + final ResumableFramesStore resumableFramesStore; + final String session; + final ByteBufAllocator allocator; + final boolean cleanupStoreOnKeepAlive; + + /** + * All incoming connections with the Resume intent are enqueued in this queue. Such an approach + * ensure that the new connection will affect the resumption state anyhow until the previous + * (active) connection is finally closed + */ + final Queue connectionsQueue; + + volatile int wip; + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(ServerRSocketSession.class, "wip"); + + volatile Subscription s; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(ServerRSocketSession.class, Subscription.class, "s"); + + KeepAliveSupport keepAliveSupport; + + public ServerRSocketSession( + ByteBuf session, + ResumableDuplexConnection resumableDuplexConnection, + DuplexConnection initialDuplexConnection, + ResumableFramesStore resumableFramesStore, + Duration resumeSessionDuration, + boolean cleanupStoreOnKeepAlive) { + this.session = session.toString(CharsetUtil.UTF_8); + this.allocator = initialDuplexConnection.alloc(); + this.resumeSessionDuration = resumeSessionDuration; + this.resumableFramesStore = resumableFramesStore; + this.cleanupStoreOnKeepAlive = cleanupStoreOnKeepAlive; + this.resumableConnection = resumableDuplexConnection; + this.connectionsQueue = Queues.unboundedMultiproducer().get(); + + WIP.lazySet(this, 1); + + resumableDuplexConnection.onClose().doFinally(__ -> dispose()).subscribe(); + resumableDuplexConnection.onActiveConnectionClosed().subscribe(__ -> tryTimeoutSession()); + } + + void tryTimeoutSession() { + keepAliveSupport.stop(); + + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Connection is lost. Trying to timeout the active session", + session); + } + + Mono.delay(resumeSessionDuration).subscribe(this); + + if (WIP.decrementAndGet(this) == 0) { + return; + } + + final Runnable doResumeRunnable = connectionsQueue.poll(); + if (doResumeRunnable != null) { + doResumeRunnable.run(); + } + } + + public void resumeWith(ByteBuf resumeFrame, DuplexConnection nextDuplexConnection) { + + if (logger.isDebugEnabled()) { + logger.debug("Side[server]|Session[{}]. New DuplexConnection received.", session); + } + + long remotePos = ResumeFrameCodec.firstAvailableClientPos(resumeFrame); + long remoteImpliedPos = ResumeFrameCodec.lastReceivedServerPos(resumeFrame); + + connectionsQueue.offer(() -> doResume(remotePos, remoteImpliedPos, nextDuplexConnection)); + + if (WIP.getAndIncrement(this) != 0) { + return; + } + + final Runnable doResumeRunnable = connectionsQueue.poll(); + if (doResumeRunnable != null) { + doResumeRunnable.run(); + } + } + + void doResume(long remotePos, long remoteImpliedPos, DuplexConnection nextDuplexConnection) { + if (!tryCancelSessionTimeout()) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Session has already been expired. Terminating received connection", + session); + } + final RejectedResumeException rejectedResumeException = + new RejectedResumeException("resume_internal_error: Session Expired"); + nextDuplexConnection.sendErrorAndClose(rejectedResumeException); + nextDuplexConnection.receive().subscribe(); + return; + } + + long impliedPosition = resumableFramesStore.frameImpliedPosition(); + long position = resumableFramesStore.framePosition(); + + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Resume FRAME received. ServerResumeState[impliedPosition[{}], position[{}]]. ClientResumeState[remoteImpliedPosition[{}], remotePosition[{}]]", + session, + impliedPosition, + position, + remoteImpliedPos, + remotePos); + } + + if (remotePos <= impliedPosition && position <= remoteImpliedPos) { + try { + if (position != remoteImpliedPos) { + resumableFramesStore.releaseFrames(remoteImpliedPos); + } + nextDuplexConnection.sendFrame(0, ResumeOkFrameCodec.encode(allocator, impliedPosition)); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. ResumeOKFrame[impliedPosition[{}]] has been sent", + session, + impliedPosition); + } + } catch (Throwable t) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Exception occurred while releasing frames in the frameStore", + session, + t); + } + + dispose(); + + final RejectedResumeException rejectedResumeException = + new RejectedResumeException(t.getMessage(), t); + nextDuplexConnection.sendErrorAndClose(rejectedResumeException); + nextDuplexConnection.receive().subscribe(); + + return; + } + + keepAliveSupport.start(); + + if (logger.isDebugEnabled()) { + logger.debug("Side[server]|Session[{}]. Session has been resumed successfully", session); + } + + if (!resumableConnection.connect(nextDuplexConnection)) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Session has already been expired. Terminating received connection", + session); + } + final RejectedResumeException rejectedResumeException = + new RejectedResumeException("resume_internal_error: Session Expired"); + nextDuplexConnection.sendErrorAndClose(rejectedResumeException); + nextDuplexConnection.receive().subscribe(); + + // resumableConnection is likely to be disposed at this stage. Thus we have + // nothing to do + } + } else { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Mismatching remote and local state. Expected RemoteImpliedPosition[{}] to be greater or equal to the LocalPosition[{}] and RemotePosition[{}] to be less or equal to LocalImpliedPosition[{}]. Terminating received connection", + session, + remoteImpliedPos, + position, + remotePos, + impliedPosition); + } + + dispose(); + + final RejectedResumeException rejectedResumeException = + new RejectedResumeException( + String.format( + "resumption_pos=[ remote: { pos: %d, impliedPos: %d }, local: { pos: %d, impliedPos: %d }]", + remotePos, remoteImpliedPos, position, impliedPosition)); + nextDuplexConnection.sendErrorAndClose(rejectedResumeException); + nextDuplexConnection.receive().subscribe(); + } + } + + boolean tryCancelSessionTimeout() { + for (; ; ) { + final Subscription subscription = this.s; + + if (subscription == Operators.cancelledSubscription()) { + return false; + } + + if (S.compareAndSet(this, subscription, null)) { + subscription.cancel(); + return true; + } + } + } + + @Override + public long impliedPosition() { + return resumableFramesStore.frameImpliedPosition(); + } + + @Override + public void onImpliedPosition(long remoteImpliedPos) { + if (cleanupStoreOnKeepAlive) { + try { + resumableFramesStore.releaseFrames(remoteImpliedPos); + } catch (Throwable e) { + resumableConnection.sendErrorAndClose(new ConnectionErrorException(e.getMessage(), e)); + } + } + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } + } + + @Override + public void onNext(Long aLong) { + if (!Operators.terminate(S, this)) { + return; + } + + resumableConnection.dispose(); + } + + @Override + public void onComplete() {} + + @Override + public void onError(Throwable t) {} + + public void setKeepAliveSupport(KeepAliveSupport keepAliveSupport) { + this.keepAliveSupport = keepAliveSupport; + } + + @Override + public void dispose() { + if (logger.isDebugEnabled()) { + logger.debug("Side[server]|Session[{}]. Disposing session", session); + } + Operators.terminate(S, this); + resumableConnection.dispose(); + } + + @Override + public boolean isDisposed() { + return resumableConnection.isDisposed(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/SessionManager.java b/rsocket-core/src/main/java/io/rsocket/resume/SessionManager.java new file mode 100644 index 000000000..736d7c77c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/SessionManager.java @@ -0,0 +1,70 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import io.netty.util.CharsetUtil; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.util.annotation.Nullable; + +public class SessionManager { + static final Logger logger = LoggerFactory.getLogger(SessionManager.class); + + private volatile boolean isDisposed; + private final Map sessions = new ConcurrentHashMap<>(); + + public ServerRSocketSession save(ServerRSocketSession session, ByteBuf resumeToken) { + if (isDisposed) { + session.dispose(); + } else { + final String token = resumeToken.toString(CharsetUtil.UTF_8); + session + .resumableConnection + .onClose() + .doFinally( + __ -> { + logger.debug( + "ResumableConnection has been closed. Removing associated session {" + + token + + "}"); + if (isDisposed || sessions.get(token) == session) { + sessions.remove(token); + } + }) + .subscribe(); + ServerRSocketSession prevSession = sessions.remove(token); + if (prevSession != null) { + prevSession.dispose(); + } + sessions.put(token, session); + } + return session; + } + + @Nullable + public ServerRSocketSession get(ByteBuf resumeToken) { + return sessions.get(resumeToken.toString(CharsetUtil.UTF_8)); + } + + public void dispose() { + isDisposed = true; + sessions.values().forEach(ServerRSocketSession::dispose); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/package-info.java b/rsocket-core/src/main/java/io/rsocket/resume/package-info.java new file mode 100644 index 000000000..98744386a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/package-info.java @@ -0,0 +1,27 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Contains support classes for the RSocket resume capability. + * + * @see Resuming + * Operation + */ +@NonNullApi +package io.rsocket.resume; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/transport/ClientTransport.java b/rsocket-core/src/main/java/io/rsocket/transport/ClientTransport.java index 7caee6299..3b8f624aa 100644 --- a/rsocket-core/src/main/java/io/rsocket/transport/ClientTransport.java +++ b/rsocket-core/src/main/java/io/rsocket/transport/ClientTransport.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -23,10 +23,9 @@ public interface ClientTransport extends Transport { /** - * Returns a {@code Publisher}, every subscription to which returns a single {@code - * DuplexConnection}. + * Return a {@code Mono} that connects for each subscriber. * - * @return {@code Publisher}, every subscription returns a single {@code DuplexConnection}. + * @since 1.0.1 */ Mono connect(); } diff --git a/rsocket-core/src/main/java/io/rsocket/transport/ServerTransport.java b/rsocket-core/src/main/java/io/rsocket/transport/ServerTransport.java index d89205540..92a9502a4 100644 --- a/rsocket-core/src/main/java/io/rsocket/transport/ServerTransport.java +++ b/rsocket-core/src/main/java/io/rsocket/transport/ServerTransport.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -26,10 +26,11 @@ public interface ServerTransport extends Transport { /** - * Starts this server. + * Start this server. * - * @param acceptor An acceptor to process a newly accepted {@code DuplexConnection} - * @return A handle to retrieve information about a started server. + * @param acceptor to process a newly accepted connections with + * @return A handle for information about and control over the server. + * @since 1.0.1 */ Mono start(ConnectionAcceptor acceptor); diff --git a/rsocket-core/src/main/java/io/rsocket/transport/Transport.java b/rsocket-core/src/main/java/io/rsocket/transport/Transport.java index efa997fb5..39386337c 100644 --- a/rsocket-core/src/main/java/io/rsocket/transport/Transport.java +++ b/rsocket-core/src/main/java/io/rsocket/transport/Transport.java @@ -1,20 +1,37 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package io.rsocket.transport; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.rsocket.DuplexConnection; + /** */ -public interface Transport {} +public interface Transport { + + /** + * Configurations that exposes the maximum frame size that a {@link DuplexConnection} can bring up + * to RSocket level. + * + *

This number should not exist the 16,777,215 (maximum frame size specified by RSocket spec) + * + * @return return maximum configured frame size limit + */ + default int maxFrameLength() { + return FRAME_LENGTH_MASK; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/transport/TransportHeaderAware.java b/rsocket-core/src/main/java/io/rsocket/transport/TransportHeaderAware.java deleted file mode 100644 index ddaed6ebf..000000000 --- a/rsocket-core/src/main/java/io/rsocket/transport/TransportHeaderAware.java +++ /dev/null @@ -1,12 +0,0 @@ -package io.rsocket.transport; - -import java.util.Map; -import java.util.function.Supplier; - -/** - * Extension interface to support Transports with headers at the transport layer, e.g. Websockets, - * Http2. - */ -public interface TransportHeaderAware { - void setTransportHeaders(Supplier> transportHeaders); -} diff --git a/rsocket-core/src/main/java/io/rsocket/transport/package-info.java b/rsocket-core/src/main/java/io/rsocket/transport/package-info.java index af03823c1..00536122a 100644 --- a/rsocket-core/src/main/java/io/rsocket/transport/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/transport/package-info.java @@ -1,18 +1,21 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -@javax.annotation.ParametersAreNonnullByDefault +/** Client and server transport contracts for pluggable transports. */ +@NonNullApi package io.rsocket.transport; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/uri/UriHandler.java b/rsocket-core/src/main/java/io/rsocket/uri/UriHandler.java deleted file mode 100644 index 05d0c4add..000000000 --- a/rsocket-core/src/main/java/io/rsocket/uri/UriHandler.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.uri; - -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import java.net.URI; -import java.util.Optional; -import java.util.ServiceLoader; - -/** - * URI to {@link ClientTransport} or {@link ServerTransport}. Should return a non empty value only - * when the URI is unambiguously mapped to a particular transport, either by a standardised - * implementation or via some flag in the URI to indicate a choice. - */ -public interface UriHandler { - static ServiceLoader loadServices() { - return ServiceLoader.load(UriHandler.class); - } - - default Optional buildClient(URI uri) { - return Optional.empty(); - } - - default Optional buildServer(URI uri) { - return Optional.empty(); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/uri/UriTransportRegistry.java b/rsocket-core/src/main/java/io/rsocket/uri/UriTransportRegistry.java deleted file mode 100644 index 276c87310..000000000 --- a/rsocket-core/src/main/java/io/rsocket/uri/UriTransportRegistry.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.uri; - -import static io.rsocket.uri.UriHandler.loadServices; - -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import java.net.URI; -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; -import java.util.ServiceLoader; -import reactor.core.publisher.Mono; - -/** - * Registry for looking up transports by URI. - * - *

Uses the Jar Services mechanism with services defined by {@link UriHandler}. - */ -public class UriTransportRegistry { - private static final ClientTransport FAILED_CLIENT_LOOKUP = - () -> Mono.error(new UnsupportedOperationException()); - private static final ServerTransport FAILED_SERVER_LOOKUP = - acceptor -> Mono.error(new UnsupportedOperationException()); - - private List handlers; - - public UriTransportRegistry(ServiceLoader services) { - handlers = new ArrayList<>(); - services.forEach(handlers::add); - } - - public static UriTransportRegistry fromServices() { - ServiceLoader services = loadServices(); - - return new UriTransportRegistry(services); - } - - public static ClientTransport clientForUri(String uri) { - return UriTransportRegistry.fromServices().findClient(uri); - } - - private ClientTransport findClient(String uriString) { - URI uri = URI.create(uriString); - - for (UriHandler h : handlers) { - Optional r = h.buildClient(uri); - if (r.isPresent()) { - return r.get(); - } - } - - return FAILED_CLIENT_LOOKUP; - } - - public static ServerTransport serverForUri(String uri) { - return UriTransportRegistry.fromServices().findServer(uri); - } - - private ServerTransport findServer(String uriString) { - URI uri = URI.create(uriString); - - for (UriHandler h : handlers) { - Optional r = h.buildServer(uri); - if (r.isPresent()) { - return r.get(); - } - } - - return FAILED_SERVER_LOOKUP; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java new file mode 100644 index 000000000..12e0b60dc --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java @@ -0,0 +1,219 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.util; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.Recycler; +import io.netty.util.Recycler.Handle; +import io.rsocket.Payload; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.Charset; +import reactor.util.annotation.Nullable; + +public final class ByteBufPayload extends AbstractReferenceCounted implements Payload { + private static final Recycler RECYCLER = + new Recycler() { + protected ByteBufPayload newObject(Handle handle) { + return new ByteBufPayload(handle); + } + }; + + private final Handle handle; + private ByteBuf data; + private ByteBuf metadata; + + private ByteBufPayload(final Handle handle) { + this.handle = handle; + } + + /** + * Static factory method for a text payload. Mainly looks better than "new ByteBufPayload(data)" + * + * @param data the data of the payload. + * @return a payload. + */ + public static Payload create(String data) { + return create(ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, data), null); + } + + /** + * Static factory method for a text payload. Mainly looks better than "new ByteBufPayload(data, + * metadata)" + * + * @param data the data of the payload. + * @param metadata the metadata for the payload. + * @return a payload. + */ + public static Payload create(String data, @Nullable String metadata) { + return create( + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, data), + metadata == null ? null : ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, metadata)); + } + + public static Payload create(CharSequence data, Charset dataCharset) { + return create( + ByteBufUtil.encodeString(ByteBufAllocator.DEFAULT, CharBuffer.wrap(data), dataCharset), + null); + } + + public static Payload create( + CharSequence data, + Charset dataCharset, + @Nullable CharSequence metadata, + Charset metadataCharset) { + return create( + ByteBufUtil.encodeString(ByteBufAllocator.DEFAULT, CharBuffer.wrap(data), dataCharset), + metadata == null + ? null + : ByteBufUtil.encodeString( + ByteBufAllocator.DEFAULT, CharBuffer.wrap(metadata), metadataCharset)); + } + + public static Payload create(byte[] data) { + return create(Unpooled.wrappedBuffer(data), null); + } + + public static Payload create(byte[] data, @Nullable byte[] metadata) { + return create( + Unpooled.wrappedBuffer(data), metadata == null ? null : Unpooled.wrappedBuffer(metadata)); + } + + public static Payload create(ByteBuffer data) { + return create(Unpooled.wrappedBuffer(data), null); + } + + public static Payload create(ByteBuffer data, @Nullable ByteBuffer metadata) { + return create( + Unpooled.wrappedBuffer(data), metadata == null ? null : Unpooled.wrappedBuffer(metadata)); + } + + public static Payload create(ByteBuf data) { + return create(data, null); + } + + public static Payload create(ByteBuf data, @Nullable ByteBuf metadata) { + ByteBufPayload payload = RECYCLER.get(); + payload.data = data; + payload.metadata = metadata; + // ensure data and metadata is set before refCnt change + payload.setRefCnt(1); + return payload; + } + + public static Payload create(Payload payload) { + return create( + payload.sliceData().retain(), + payload.hasMetadata() ? payload.sliceMetadata().retain() : null); + } + + @Override + public boolean hasMetadata() { + ensureAccessible(); + return metadata != null; + } + + @Override + public ByteBuf sliceMetadata() { + ensureAccessible(); + return metadata == null ? Unpooled.EMPTY_BUFFER : metadata.slice(); + } + + @Override + public ByteBuf data() { + ensureAccessible(); + return data; + } + + @Override + public ByteBuf metadata() { + ensureAccessible(); + return metadata == null ? Unpooled.EMPTY_BUFFER : metadata; + } + + @Override + public ByteBuf sliceData() { + ensureAccessible(); + return data.slice(); + } + + @Override + public ByteBufPayload retain() { + super.retain(); + return this; + } + + @Override + public ByteBufPayload retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public ByteBufPayload touch() { + ensureAccessible(); + data.touch(); + if (metadata != null) { + metadata.touch(); + } + return this; + } + + @Override + public ByteBufPayload touch(Object hint) { + ensureAccessible(); + data.touch(hint); + if (metadata != null) { + metadata.touch(hint); + } + return this; + } + + @Override + protected void deallocate() { + data.release(); + data = null; + if (metadata != null) { + metadata.release(); + metadata = null; + } + handle.recycle(this); + } + + /** + * Should be called by every method that tries to access the buffers content to check if the + * buffer was released before. + */ + void ensureAccessible() { + if (!isAccessible()) { + throw new IllegalReferenceCountException(0); + } + } + + /** + * Used internally by {@link ByteBufPayload#ensureAccessible()} to try to guard against using the + * buffer after it was released (best-effort). + */ + boolean isAccessible() { + return refCnt() != 0; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java b/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java new file mode 100644 index 000000000..328fb8435 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java @@ -0,0 +1,210 @@ +package io.rsocket.util; + +import static io.netty.util.internal.StringUtil.isSurrogate; + +import io.netty.buffer.ByteBuf; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.MathUtil; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.CharacterCodingException; +import java.nio.charset.CharsetDecoder; +import java.nio.charset.CoderResult; +import java.util.Arrays; + +public class CharByteBufUtil { + + private static final byte WRITE_UTF_UNKNOWN = (byte) '?'; + + private CharByteBufUtil() {} + + /** + * Returns the exact bytes length of UTF8 character sequence. + * + *

This method is producing the exact length according to {@link #writeUtf8(ByteBuf, char[])}. + */ + public static int utf8Bytes(final char[] seq) { + return utf8ByteCount(seq, 0, seq.length); + } + + /** + * This method is producing the exact length according to {@link #writeUtf8(ByteBuf, char[], int, + * int)}. + */ + public static int utf8Bytes(final char[] seq, int start, int end) { + return utf8ByteCount(checkCharSequenceBounds(seq, start, end), start, end); + } + + private static int utf8ByteCount(final char[] seq, int start, int end) { + int i = start; + // ASCII fast path + while (i < end && seq[i] < 0x80) { + ++i; + } + // !ASCII is packed in a separate method to let the ASCII case be smaller + return i < end ? (i - start) + utf8BytesNonAscii(seq, i, end) : i - start; + } + + private static int utf8BytesNonAscii(final char[] seq, final int start, final int end) { + int encodedLength = 0; + for (int i = start; i < end; i++) { + final char c = seq[i]; + // making it 100% branchless isn't rewarding due to the many bit operations necessary! + if (c < 0x800) { + // branchless version of: (c <= 127 ? 0:1) + 1 + encodedLength += ((0x7f - c) >>> 31) + 1; + } else if (isSurrogate(c)) { + if (!Character.isHighSurrogate(c)) { + encodedLength++; + // WRITE_UTF_UNKNOWN + continue; + } + final char c2; + try { + // Surrogate Pair consumes 2 characters. Optimistically try to get the next character to + // avoid + // duplicate bounds checking with charAt. + c2 = seq[++i]; + } catch (IndexOutOfBoundsException ignored) { + encodedLength++; + // WRITE_UTF_UNKNOWN + break; + } + if (!Character.isLowSurrogate(c2)) { + // WRITE_UTF_UNKNOWN + (Character.isHighSurrogate(c2) ? WRITE_UTF_UNKNOWN : c2) + encodedLength += 2; + continue; + } + // See http://www.unicode.org/versions/Unicode7.0.0/ch03.pdf#G2630. + encodedLength += 4; + } else { + encodedLength += 3; + } + } + return encodedLength; + } + + private static char[] checkCharSequenceBounds(char[] seq, int start, int end) { + if (MathUtil.isOutOfBounds(start, end - start, seq.length)) { + throw new IndexOutOfBoundsException( + "expected: 0 <= start(" + + start + + ") <= end (" + + end + + ") <= seq.length(" + + seq.length + + ')'); + } + return seq; + } + + /** + * Encode a {@code char[]} in UTF-8 and write it + * into {@link ByteBuf}. + * + *

This method returns the actual number of bytes written. + */ + public static int writeUtf8(ByteBuf buf, char[] seq) { + return writeUtf8(buf, seq, 0, seq.length); + } + + /** + * Equivalent to {@link #writeUtf8(ByteBuf, char[]) writeUtf8(buf, seq.subSequence(start, end), + * reserveBytes)} but avoids subsequence object allocation if possible. + * + * @return actual number of bytes written + */ + public static int writeUtf8(ByteBuf buf, char[] seq, int start, int end) { + return writeUtf8(buf, buf.writerIndex(), checkCharSequenceBounds(seq, start, end), start, end); + } + + // Fast-Path implementation + static int writeUtf8(ByteBuf buffer, int writerIndex, char[] seq, int start, int end) { + int oldWriterIndex = writerIndex; + + // We can use the _set methods as these not need to do any index checks and reference checks. + // This is possible as we called ensureWritable(...) before. + for (int i = start; i < end; i++) { + char c = seq[i]; + if (c < 0x80) { + buffer.setByte(writerIndex++, (byte) c); + } else if (c < 0x800) { + buffer.setByte(writerIndex++, (byte) (0xc0 | (c >> 6))); + buffer.setByte(writerIndex++, (byte) (0x80 | (c & 0x3f))); + } else if (isSurrogate(c)) { + if (!Character.isHighSurrogate(c)) { + buffer.setByte(writerIndex++, WRITE_UTF_UNKNOWN); + continue; + } + final char c2; + if (seq.length > ++i) { + // Surrogate Pair consumes 2 characters. Optimistically try to get the next character to + // avoid + // duplicate bounds checking with charAt. If an IndexOutOfBoundsException is thrown we + // will + // re-throw a more informative exception describing the problem. + c2 = seq[i]; + } else { + buffer.setByte(writerIndex++, WRITE_UTF_UNKNOWN); + break; + } + // Extra method to allow inlining the rest of writeUtf8 which is the most likely code path. + writerIndex = writeUtf8Surrogate(buffer, writerIndex, c, c2); + } else { + buffer.setByte(writerIndex++, (byte) (0xe0 | (c >> 12))); + buffer.setByte(writerIndex++, (byte) (0x80 | ((c >> 6) & 0x3f))); + buffer.setByte(writerIndex++, (byte) (0x80 | (c & 0x3f))); + } + } + buffer.writerIndex(writerIndex); + return writerIndex - oldWriterIndex; + } + + private static int writeUtf8Surrogate(ByteBuf buffer, int writerIndex, char c, char c2) { + if (!Character.isLowSurrogate(c2)) { + buffer.setByte(writerIndex++, WRITE_UTF_UNKNOWN); + buffer.setByte(writerIndex++, Character.isHighSurrogate(c2) ? WRITE_UTF_UNKNOWN : c2); + return writerIndex; + } + int codePoint = Character.toCodePoint(c, c2); + // See http://www.unicode.org/versions/Unicode7.0.0/ch03.pdf#G2630. + buffer.setByte(writerIndex++, (byte) (0xf0 | (codePoint >> 18))); + buffer.setByte(writerIndex++, (byte) (0x80 | ((codePoint >> 12) & 0x3f))); + buffer.setByte(writerIndex++, (byte) (0x80 | ((codePoint >> 6) & 0x3f))); + buffer.setByte(writerIndex++, (byte) (0x80 | (codePoint & 0x3f))); + return writerIndex; + } + + public static char[] readUtf8(ByteBuf byteBuf, int length) { + CharsetDecoder charsetDecoder = CharsetUtil.UTF_8.newDecoder(); + int en = (int) (length * (double) charsetDecoder.maxCharsPerByte()); + char[] ca = new char[en]; + + CharBuffer charBuffer = CharBuffer.wrap(ca); + ByteBuffer byteBuffer = + byteBuf.nioBufferCount() == 1 + ? byteBuf.internalNioBuffer(byteBuf.readerIndex(), length) + : byteBuf.nioBuffer(byteBuf.readerIndex(), length); + byteBuffer.mark(); + try { + CoderResult cr = charsetDecoder.decode(byteBuffer, charBuffer, true); + if (!cr.isUnderflow()) cr.throwException(); + cr = charsetDecoder.flush(charBuffer); + if (!cr.isUnderflow()) cr.throwException(); + + byteBuffer.reset(); + byteBuf.skipBytes(length); + + return safeTrim(charBuffer.array(), charBuffer.position()); + } catch (CharacterCodingException x) { + // Substitution is always enabled, + // so this shouldn't happen + throw new IllegalStateException("unable to decode char array from the given buffer", x); + } + } + + private static char[] safeTrim(char[] ca, int len) { + if (len == ca.length) return ca; + else return Arrays.copyOf(ca, len); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/util/Clock.java b/rsocket-core/src/main/java/io/rsocket/util/Clock.java index 9b7704554..4a34c988f 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/Clock.java +++ b/rsocket-core/src/main/java/io/rsocket/util/Clock.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package io.rsocket.util; import java.util.concurrent.TimeUnit; diff --git a/rsocket-core/src/main/java/io/rsocket/util/CloseableAdapter.java b/rsocket-core/src/main/java/io/rsocket/util/CloseableAdapter.java deleted file mode 100644 index a4efce8b5..000000000 --- a/rsocket-core/src/main/java/io/rsocket/util/CloseableAdapter.java +++ /dev/null @@ -1,29 +0,0 @@ -package io.rsocket.util; - -import io.rsocket.Closeable; -import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; - -public class CloseableAdapter implements Closeable { - private final MonoProcessor onClose = MonoProcessor.create(); - private Runnable closeFunction; - - public CloseableAdapter(Runnable closeFunction) { - this.closeFunction = closeFunction; - } - - @Override - public Mono close() { - return Mono.defer( - () -> { - closeFunction.run(); - onClose.onComplete(); - return onClose; - }); - } - - @Override - public Mono onClose() { - return onClose; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java b/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java new file mode 100644 index 000000000..08b8b2fb7 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java @@ -0,0 +1,194 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.util; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import reactor.util.annotation.Nullable; + +/** + * An implementation of {@link Payload}. This implementation is not thread-safe, and hence + * any method can not be invoked concurrently. + */ +public final class DefaultPayload implements Payload { + public static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocateDirect(0); + + private final ByteBuffer data; + private final ByteBuffer metadata; + + private DefaultPayload(ByteBuffer data, @Nullable ByteBuffer metadata) { + this.data = data; + this.metadata = metadata; + } + + /** + * Static factory method for a text payload. Mainly looks better than "new DefaultPayload(data)" + * + * @param data the data of the payload. + * @return a payload. + */ + public static Payload create(CharSequence data) { + return create(StandardCharsets.UTF_8.encode(CharBuffer.wrap(data)), null); + } + + /** + * Static factory method for a text payload. Mainly looks better than "new DefaultPayload(data, + * metadata)" + * + * @param data the data of the payload. + * @param metadata the metadata for the payload. + * @return a payload. + */ + public static Payload create(CharSequence data, @Nullable CharSequence metadata) { + return create( + StandardCharsets.UTF_8.encode(CharBuffer.wrap(data)), + metadata == null ? null : StandardCharsets.UTF_8.encode(CharBuffer.wrap(metadata))); + } + + public static Payload create(CharSequence data, Charset dataCharset) { + return create(dataCharset.encode(CharBuffer.wrap(data)), null); + } + + public static Payload create( + CharSequence data, + Charset dataCharset, + @Nullable CharSequence metadata, + Charset metadataCharset) { + return create( + dataCharset.encode(CharBuffer.wrap(data)), + metadata == null ? null : metadataCharset.encode(CharBuffer.wrap(metadata))); + } + + public static Payload create(byte[] data) { + return create(ByteBuffer.wrap(data), null); + } + + public static Payload create(byte[] data, @Nullable byte[] metadata) { + return create(ByteBuffer.wrap(data), metadata == null ? null : ByteBuffer.wrap(metadata)); + } + + public static Payload create(ByteBuffer data) { + return create(data, null); + } + + public static Payload create(ByteBuffer data, @Nullable ByteBuffer metadata) { + return new DefaultPayload(data, metadata); + } + + public static Payload create(ByteBuf data) { + return create(data, null); + } + + public static Payload create(ByteBuf data, @Nullable ByteBuf metadata) { + try { + return create(toBytes(data), metadata != null ? toBytes(metadata) : null); + } finally { + data.release(); + if (metadata != null) { + metadata.release(); + } + } + } + + public static Payload create(Payload payload) { + return create( + toBytes(payload.data()), payload.hasMetadata() ? toBytes(payload.metadata()) : null); + } + + private static byte[] toBytes(ByteBuf byteBuf) { + byte[] bytes = new byte[byteBuf.readableBytes()]; + byteBuf.markReaderIndex(); + byteBuf.readBytes(bytes); + byteBuf.resetReaderIndex(); + return bytes; + } + + @Override + public boolean hasMetadata() { + return metadata != null; + } + + @Override + public ByteBuf sliceMetadata() { + return metadata == null ? Unpooled.EMPTY_BUFFER : Unpooled.wrappedBuffer(metadata); + } + + @Override + public ByteBuf sliceData() { + return Unpooled.wrappedBuffer(data); + } + + @Override + public ByteBuffer getMetadata() { + return metadata == null ? DefaultPayload.EMPTY_BUFFER : metadata.duplicate(); + } + + @Override + public ByteBuffer getData() { + return data.duplicate(); + } + + @Override + public ByteBuf data() { + return sliceData(); + } + + @Override + public ByteBuf metadata() { + return sliceMetadata(); + } + + @Override + public int refCnt() { + return 1; + } + + @Override + public DefaultPayload retain() { + return this; + } + + @Override + public DefaultPayload retain(int increment) { + return this; + } + + @Override + public DefaultPayload touch() { + return this; + } + + @Override + public DefaultPayload touch(Object hint) { + return this; + } + + @Override + public boolean release() { + return false; + } + + @Override + public boolean release(int decrement) { + return false; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/util/EmptyPayload.java b/rsocket-core/src/main/java/io/rsocket/util/EmptyPayload.java new file mode 100644 index 000000000..99df97d70 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/util/EmptyPayload.java @@ -0,0 +1,87 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.util; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; + +public class EmptyPayload implements Payload { + public static final EmptyPayload INSTANCE = new EmptyPayload(); + + private EmptyPayload() {} + + @Override + public boolean hasMetadata() { + return false; + } + + @Override + public ByteBuf sliceMetadata() { + return Unpooled.EMPTY_BUFFER; + } + + @Override + public ByteBuf sliceData() { + return Unpooled.EMPTY_BUFFER; + } + + @Override + public ByteBuf data() { + return sliceData(); + } + + @Override + public ByteBuf metadata() { + return sliceMetadata(); + } + + @Override + public int refCnt() { + return 1; + } + + @Override + public EmptyPayload retain() { + return this; + } + + @Override + public EmptyPayload retain(int increment) { + return this; + } + + @Override + public EmptyPayload touch() { + return this; + } + + @Override + public EmptyPayload touch(Object hint) { + return this; + } + + @Override + public boolean release() { + return false; + } + + @Override + public boolean release(int decrement) { + return false; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/util/ExceptionUtil.java b/rsocket-core/src/main/java/io/rsocket/util/ExceptionUtil.java deleted file mode 100644 index 140435004..000000000 --- a/rsocket-core/src/main/java/io/rsocket/util/ExceptionUtil.java +++ /dev/null @@ -1,11 +0,0 @@ -package io.rsocket.util; - -public class ExceptionUtil { - public static T noStacktrace(T ex) { - ex.setStackTrace( - new StackTraceElement[] { - new StackTraceElement(ex.getClass().getName(), "", null, -1) - }); - return ex; - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/util/NumberUtils.java b/rsocket-core/src/main/java/io/rsocket/util/NumberUtils.java new file mode 100644 index 000000000..3ff720447 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/util/NumberUtils.java @@ -0,0 +1,164 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.util; + +import io.netty.buffer.ByteBuf; +import java.util.Objects; + +public final class NumberUtils { + + /** The size of a medium in {@code byte}s. */ + public static final int MEDIUM_BYTES = 3; + + private static final int UNSIGNED_BYTE_SIZE = 8; + + private static final int UNSIGNED_BYTE_MAX_VALUE = (1 << UNSIGNED_BYTE_SIZE) - 1; + + private static final int UNSIGNED_MEDIUM_SIZE = 24; + + private static final int UNSIGNED_MEDIUM_MAX_VALUE = (1 << UNSIGNED_MEDIUM_SIZE) - 1; + + private static final int UNSIGNED_SHORT_SIZE = 16; + + private static final int UNSIGNED_SHORT_MAX_VALUE = (1 << UNSIGNED_SHORT_SIZE) - 1; + + private NumberUtils() {} + + /** + * Requires that an {@code int} is greater than or equal to zero. + * + * @param i the {@code int} to test + * @param message detail message to be used in the event that a {@link IllegalArgumentException} + * is thrown + * @return the {@code int} if greater than or equal to zero + * @throws IllegalArgumentException if {@code i} is less than zero + */ + public static int requireNonNegative(int i, String message) { + Objects.requireNonNull(message, "message must not be null"); + + if (i < 0) { + throw new IllegalArgumentException(message); + } + + return i; + } + + /** + * Requires that a {@code long} is greater than zero. + * + * @param l the {@code long} to test + * @param message detail message to be used in the event that a {@link IllegalArgumentException} + * is thrown + * @return the {@code long} if greater than zero + * @throws IllegalArgumentException if {@code l} is less than or equal to zero + */ + public static long requirePositive(long l, String message) { + Objects.requireNonNull(message, "message must not be null"); + + if (l <= 0) { + throw new IllegalArgumentException(message); + } + + return l; + } + + /** + * Requires that an {@code int} is greater than zero. + * + * @param i the {@code int} to test + * @param message detail message to be used in the event that a {@link IllegalArgumentException} + * is thrown + * @return the {@code int} if greater than zero + * @throws IllegalArgumentException if {@code i} is less than or equal to zero + */ + public static int requirePositive(int i, String message) { + Objects.requireNonNull(message, "message must not be null"); + + if (i <= 0) { + throw new IllegalArgumentException(message); + } + + return i; + } + + /** + * Requires that an {@code int} can be represented as an unsigned {@code byte}. + * + * @param i the {@code int} to test + * @return the {@code int} if it can be represented as an unsigned {@code byte} + * @throws IllegalArgumentException if {@code i} cannot be represented as an unsigned {@code byte} + */ + public static int requireUnsignedByte(int i) { + if (i > UNSIGNED_BYTE_MAX_VALUE) { + throw new IllegalArgumentException( + String.format("%d is larger than %d bits", i, UNSIGNED_BYTE_SIZE)); + } + + return i; + } + + /** + * Requires that an {@code int} can be represented as an unsigned {@code medium}. + * + * @param i the {@code int} to test + * @return the {@code int} if it can be represented as an unsigned {@code medium} + * @throws IllegalArgumentException if {@code i} cannot be represented as an unsigned {@code + * medium} + */ + public static int requireUnsignedMedium(int i) { + if (i > UNSIGNED_MEDIUM_MAX_VALUE) { + throw new IllegalArgumentException( + String.format("%d is larger than %d bits", i, UNSIGNED_MEDIUM_SIZE)); + } + + return i; + } + + /** + * Requires that an {@code int} can be represented as an unsigned {@code short}. + * + * @param i the {@code int} to test + * @return the {@code int} if it can be represented as an unsigned {@code short} + * @throws IllegalArgumentException if {@code i} cannot be represented as an unsigned {@code + * short} + */ + public static int requireUnsignedShort(int i) { + if (i > UNSIGNED_SHORT_MAX_VALUE) { + throw new IllegalArgumentException( + String.format("%d is larger than %d bits", i, UNSIGNED_SHORT_SIZE)); + } + + return i; + } + + /** + * Encode an unsigned medium integer on 3 bytes / 24 bits. This can be decoded directly by the + * {@link ByteBuf#readUnsignedMedium()} method. + * + * @param byteBuf the {@link ByteBuf} into which to write the bits + * @param i the medium integer to encode + * @see #requireUnsignedMedium(int) + */ + public static void encodeUnsignedMedium(ByteBuf byteBuf, int i) { + requireUnsignedMedium(i); + // Write each byte separately in reverse order, this mean we can write 1 << 23 without + // overflowing. + byteBuf.writeByte(i >> 16); + byteBuf.writeByte(i >> 8); + byteBuf.writeByte(i); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/util/PayloadImpl.java b/rsocket-core/src/main/java/io/rsocket/util/PayloadImpl.java deleted file mode 100644 index e86dd270e..000000000 --- a/rsocket-core/src/main/java/io/rsocket/util/PayloadImpl.java +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.util; - -import io.rsocket.Frame; -import io.rsocket.Payload; -import java.nio.ByteBuffer; -import java.nio.charset.Charset; -import java.nio.charset.StandardCharsets; -import javax.annotation.Nullable; - -/** - * An implementation of {@link Payload}. This implementation is not thread-safe, and hence - * any method can not be invoked concurrently. - */ -public class PayloadImpl implements Payload { - - public static final PayloadImpl EMPTY = - new PayloadImpl(Frame.NULL_BYTEBUFFER, Frame.NULL_BYTEBUFFER, false); - - private final ByteBuffer data; - private final ByteBuffer metadata; - private final int dataStartPosition; - private final int metadataStartPosition; - private final boolean reusable; - - public PayloadImpl(Frame frame) { - this(frame.getData(), frame.hasMetadata() ? frame.getMetadata() : null); - } - - public PayloadImpl(String data) { - this(data, Charset.defaultCharset()); - } - - public PayloadImpl(String data, @Nullable String metadata) { - this(data, StandardCharsets.UTF_8, metadata, StandardCharsets.UTF_8); - } - - public PayloadImpl(String data, Charset dataCharset) { - this(dataCharset.encode(data), null); - } - - public PayloadImpl( - String data, Charset dataCharset, @Nullable String metadata, Charset metaDataCharset) { - this(dataCharset.encode(data), metadata == null ? null : metaDataCharset.encode(metadata)); - } - - public PayloadImpl(byte[] data) { - this(ByteBuffer.wrap(data), Frame.NULL_BYTEBUFFER); - } - - public PayloadImpl(byte[] data, @Nullable byte[] metadata) { - this(ByteBuffer.wrap(data), metadata == null ? null : ByteBuffer.wrap(metadata)); - } - - public PayloadImpl(ByteBuffer data) { - this(data, Frame.NULL_BYTEBUFFER); - } - - public PayloadImpl(ByteBuffer data, @Nullable ByteBuffer metadata) { - this(data, metadata, true); - } - - public PayloadImpl(ByteBuffer data, ByteBuffer metadata, boolean reusable) { - this.data = data; - this.metadata = metadata; - this.reusable = reusable; - this.dataStartPosition = reusable ? this.data.position() : 0; - this.metadataStartPosition = (reusable && metadata != null) ? this.metadata.position() : 0; - } - - @Override - public ByteBuffer getData() { - if (reusable) { - data.position(dataStartPosition); - } - return data; - } - - @Override - public ByteBuffer getMetadata() { - if (metadata == null) { - return Frame.NULL_BYTEBUFFER; - } - if (reusable) { - metadata.position(metadataStartPosition); - } - return metadata; - } - - @Override - public boolean hasMetadata() { - return metadata != null; - } - - /** - * Static factory method for a text payload. Mainly looks better than "new PayloadImpl(data)" - * - * @param data the data of the payload. - * @return a payload. - */ - public static Payload textPayload(String data) { - return new PayloadImpl(data); - } - - /** - * Static factory method for a text payload. Mainly looks better than "new PayloadImpl(data, - * metadata)" - * - * @param data the data of the payload. - * @param metadata the metadata for the payload. - * @return a payload. - */ - public static Payload textPayload(String data, @Nullable String metadata) { - return new PayloadImpl(data, metadata); - } -} diff --git a/rsocket-core/src/main/java/io/rsocket/util/RSocketProxy.java b/rsocket-core/src/main/java/io/rsocket/util/RSocketProxy.java index 29aefa742..518b727c1 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/RSocketProxy.java +++ b/rsocket-core/src/main/java/io/rsocket/util/RSocketProxy.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package io.rsocket.util; import io.rsocket.Payload; @@ -60,8 +61,13 @@ public double availability() { } @Override - public Mono close() { - return source.close(); + public void dispose() { + source.dispose(); + } + + @Override + public boolean isDisposed() { + return source.isDisposed(); } @Override diff --git a/rsocket-core/src/main/java/io/rsocket/util/package-info.java b/rsocket-core/src/main/java/io/rsocket/util/package-info.java index 435eef5c1..2fac3327f 100644 --- a/rsocket-core/src/main/java/io/rsocket/util/package-info.java +++ b/rsocket-core/src/main/java/io/rsocket/util/package-info.java @@ -1,18 +1,21 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ -@javax.annotation.ParametersAreNonnullByDefault +/** Shared utility classes and {@link io.rsocket.Payload} implementations. */ +@NonNullApi package io.rsocket.util; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/resources/META-INF/native-image/io.rsocket/rsocket-core/reflect-config.json b/rsocket-core/src/main/resources/META-INF/native-image/io.rsocket/rsocket-core/reflect-config.json new file mode 100644 index 000000000..0a3844451 --- /dev/null +++ b/rsocket-core/src/main/resources/META-INF/native-image/io.rsocket/rsocket-core/reflect-config.json @@ -0,0 +1,130 @@ +[ + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.BaseLinkedQueueConsumerNodeRef" + }, + "name": "io.rsocket.internal.jctools.queues.BaseLinkedQueueConsumerNodeRef", + "fields": [ + { + "name": "consumerNode" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.BaseLinkedQueueProducerNodeRef" + }, + "name": "io.rsocket.internal.jctools.queues.BaseLinkedQueueProducerNodeRef", + "fields": [ + { + "name": "producerNode" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueColdProducerFields" + }, + "name": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueColdProducerFields", + "fields": [ + { + "name": "producerLimit" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueConsumerFields" + }, + "name": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueConsumerFields", + "fields": [ + { + "name": "consumerIndex" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueProducerFields" + }, + "name": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueProducerFields", + "fields": [ + { + "name": "producerIndex" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.LinkedQueueNode" + }, + "name": "io.rsocket.internal.jctools.queues.LinkedQueueNode", + "fields": [ + { + "name": "next" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.MpscArrayQueueConsumerIndexField" + }, + "name": "io.rsocket.internal.jctools.queues.MpscArrayQueueConsumerIndexField", + "fields": [ + { + "name": "consumerIndex" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.MpscArrayQueueProducerIndexField" + }, + "name": "io.rsocket.internal.jctools.queues.MpscArrayQueueProducerIndexField", + "fields": [ + { + "name": "producerIndex" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.MpscArrayQueueProducerLimitField" + }, + "name": "io.rsocket.internal.jctools.queues.MpscArrayQueueProducerLimitField", + "fields": [ + { + "name": "producerLimit" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.UnsafeAccess" + }, + "name": "sun.misc.Unsafe", + "fields": [ + { + "name": "theUnsafe" + } + ], + "queriedMethods": [ + { + "name": "getAndAddLong", + "parameterTypes": [ + "java.lang.Object", + "long", + "long" + ] + }, + { + "name": "getAndSetObject", + "parameterTypes": [ + "java.lang.Object", + "long", + "java.lang.Object" + ] + } + ] + } +] \ No newline at end of file diff --git a/rsocket-core/src/test/java/io/rsocket/AbstractSocketRule.java b/rsocket-core/src/test/java/io/rsocket/AbstractSocketRule.java deleted file mode 100644 index ecd0f1e64..000000000 --- a/rsocket-core/src/test/java/io/rsocket/AbstractSocketRule.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import io.rsocket.test.util.TestDuplexConnection; -import io.rsocket.test.util.TestSubscriber; -import java.util.concurrent.ConcurrentLinkedQueue; -import org.junit.rules.ExternalResource; -import org.junit.runner.Description; -import org.junit.runners.model.Statement; -import org.reactivestreams.Subscriber; - -public abstract class AbstractSocketRule extends ExternalResource { - - protected TestDuplexConnection connection; - protected Subscriber connectSub; - protected T socket; - protected ConcurrentLinkedQueue errors; - - @Override - public Statement apply(final Statement base, Description description) { - return new Statement() { - @Override - public void evaluate() throws Throwable { - connection = new TestDuplexConnection(); - connectSub = TestSubscriber.create(); - errors = new ConcurrentLinkedQueue<>(); - init(); - base.evaluate(); - } - }; - } - - protected void init() { - socket = newRSocket(); - } - - protected abstract T newRSocket(); -} diff --git a/rsocket-core/src/test/java/io/rsocket/FrameAssert.java b/rsocket-core/src/test/java/io/rsocket/FrameAssert.java new file mode 100644 index 000000000..b5b1e2ec9 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/FrameAssert.java @@ -0,0 +1,336 @@ +package io.rsocket; + +import static org.assertj.core.error.ShouldBe.shouldBe; +import static org.assertj.core.error.ShouldBeEqual.shouldBeEqual; +import static org.assertj.core.error.ShouldHave.shouldHave; +import static org.assertj.core.error.ShouldNotHave.shouldNotHave; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.rsocket.frame.*; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.List; +import org.assertj.core.api.AbstractAssert; +import org.assertj.core.api.Condition; +import org.assertj.core.error.BasicErrorMessageFactory; +import org.assertj.core.internal.Failures; +import org.assertj.core.internal.Objects; +import reactor.util.annotation.Nullable; + +public class FrameAssert extends AbstractAssert { + public static FrameAssert assertThat(@Nullable ByteBuf frame) { + return new FrameAssert(frame); + } + + private final Failures failures = Failures.instance(); + + public FrameAssert(@Nullable ByteBuf frame) { + super(frame, FrameAssert.class); + } + + public FrameAssert hasMetadata() { + assertValid(); + + if (!FrameHeaderCodec.hasMetadata(actual)) { + throw failures.failure(info, shouldHave(actual, new Condition<>("metadata present"))); + } + + return this; + } + + public FrameAssert hasNoMetadata() { + assertValid(); + + if (FrameHeaderCodec.hasMetadata(actual)) { + throw failures.failure(info, shouldHave(actual, new Condition<>("metadata absent"))); + } + + return this; + } + + public FrameAssert hasMetadata(String metadata, Charset charset) { + return hasMetadata(metadata.getBytes(charset)); + } + + public FrameAssert hasMetadata(String metadataUtf8) { + return hasMetadata(metadataUtf8, CharsetUtil.UTF_8); + } + + public FrameAssert hasMetadata(byte[] metadata) { + return hasMetadata(Unpooled.wrappedBuffer(metadata)); + } + + public FrameAssert hasMetadata(ByteBuf metadata) { + hasMetadata(); + + final FrameType frameType = FrameHeaderCodec.frameType(actual); + ByteBuf content; + if (frameType == FrameType.METADATA_PUSH) { + content = MetadataPushFrameCodec.metadata(actual); + } else if (frameType.hasInitialRequestN()) { + content = RequestStreamFrameCodec.metadata(actual); + } else { + content = PayloadFrameCodec.metadata(actual); + } + + if (!ByteBufUtil.equals(content, metadata)) { + throw failures.failure(info, shouldBeEqual(content, metadata, new ByteBufRepresentation())); + } + + return this; + } + + public FrameAssert hasData(String dataUtf8) { + return hasData(dataUtf8, CharsetUtil.UTF_8); + } + + public FrameAssert hasData(String data, Charset charset) { + return hasData(data.getBytes(charset)); + } + + public FrameAssert hasData(byte[] data) { + return hasData(Unpooled.wrappedBuffer(data)); + } + + public FrameAssert hasData(ByteBuf data) { + assertValid(); + + ByteBuf content; + final FrameType frameType = FrameHeaderCodec.frameType(actual); + if (!frameType.canHaveData()) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have data content but frame type %n<%s> does not support data content", + actual, frameType)); + } else if (frameType.hasInitialRequestN()) { + content = RequestStreamFrameCodec.data(actual); + } else if (frameType == FrameType.ERROR) { + content = ErrorFrameCodec.data(actual); + } else { + content = PayloadFrameCodec.data(actual); + } + + if (!ByteBufUtil.equals(content, data)) { + throw failures.failure(info, shouldBeEqual(content, data, new ByteBufRepresentation())); + } + + return this; + } + + public FrameAssert hasFragmentsFollow() { + return hasFollows(true); + } + + public FrameAssert hasNoFragmentsFollow() { + return hasFollows(false); + } + + public FrameAssert hasFollows(boolean hasFollows) { + assertValid(); + + if (FrameHeaderCodec.hasFollows(actual) != hasFollows) { + throw failures.failure( + info, + hasFollows + ? shouldHave(actual, new Condition<>("follows fragment present")) + : shouldNotHave(actual, new Condition<>("follows fragment present"))); + } + + return this; + } + + public FrameAssert typeOf(FrameType frameType) { + assertValid(); + + final FrameType currentFrameType = FrameHeaderCodec.frameType(actual); + if (currentFrameType != frameType) { + throw failures.failure( + info, shouldBe(currentFrameType, new Condition<>("frame of type [" + frameType + "]"))); + } + + return this; + } + + public FrameAssert hasStreamId(int streamId) { + assertValid(); + + final int currentStreamId = FrameHeaderCodec.streamId(actual); + if (currentStreamId != streamId) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting streamId:%n<%s>%n to be equal %n<%s>", currentStreamId, streamId)); + } + + return this; + } + + public FrameAssert hasStreamIdZero() { + return hasStreamId(0); + } + + public FrameAssert hasClientSideStreamId() { + assertValid(); + + final int currentStreamId = FrameHeaderCodec.streamId(actual); + if (currentStreamId % 2 != 1) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting Client Side StreamId %nbut was " + + (currentStreamId == 0 ? "Stream Id 0" : "Server Side Stream Id"))); + } + + return this; + } + + public FrameAssert hasServerSideStreamId() { + assertValid(); + + final int currentStreamId = FrameHeaderCodec.streamId(actual); + if (currentStreamId == 0 || currentStreamId % 2 != 0) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting %n Server Side Stream Id %nbut was %n " + + (currentStreamId == 0 ? "Stream Id 0" : "Client Side Stream Id"))); + } + + return this; + } + + public FrameAssert hasPayloadSize(int payloadLength) { + assertValid(); + + final FrameType currentFrameType = FrameHeaderCodec.frameType(actual); + + final int currentFrameLength = + actual.readableBytes() + - FrameHeaderCodec.size() + - (FrameHeaderCodec.hasMetadata(actual) && currentFrameType.canHaveData() ? 3 : 0) + - (currentFrameType.hasInitialRequestN() ? Integer.BYTES : 0); + if (currentFrameLength != payloadLength) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting %n<%s> %nframe payload size to be equal to %n<%s> %nbut was %n<%s>", + actual, payloadLength, currentFrameLength)); + } + + return this; + } + + public FrameAssert hasRequestN(int n) { + assertValid(); + + final FrameType currentFrameType = FrameHeaderCodec.frameType(actual); + long requestN; + if (currentFrameType.hasInitialRequestN()) { + requestN = RequestStreamFrameCodec.initialRequestN(actual); + } else if (currentFrameType == FrameType.REQUEST_N) { + requestN = RequestNFrameCodec.requestN(actual); + } else { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have requestN but frame type %n<%s> does not support requestN", + actual, currentFrameType)); + } + + if ((requestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : requestN) != n) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have %nrequestN(<%s>) but got %nrequestN(<%s>)", + actual, n, requestN)); + } + + return this; + } + + public FrameAssert hasPayload(Payload expectedPayload) { + assertValid(); + + List failedExpectation = new ArrayList<>(); + FrameType frameType = FrameHeaderCodec.frameType(actual); + boolean hasMetadata = FrameHeaderCodec.hasMetadata(actual); + if (expectedPayload.hasMetadata() != hasMetadata) { + failedExpectation.add( + String.format( + "hasMetadata(%s) but actual was hasMetadata(%s)%n", + expectedPayload.hasMetadata(), hasMetadata)); + } else if (hasMetadata) { + ByteBuf metadataContent; + if (frameType == FrameType.METADATA_PUSH) { + metadataContent = MetadataPushFrameCodec.metadata(actual); + } else if (frameType.hasInitialRequestN()) { + metadataContent = RequestStreamFrameCodec.metadata(actual); + } else { + metadataContent = PayloadFrameCodec.metadata(actual); + } + if (!ByteBufUtil.equals(expectedPayload.sliceMetadata(), metadataContent)) { + failedExpectation.add( + String.format( + "metadata(%s) but actual was metadata(%s)%n", + expectedPayload.sliceMetadata(), metadataContent)); + } + } + + ByteBuf dataContent; + if (!frameType.canHaveData() && expectedPayload.sliceData().readableBytes() > 0) { + failedExpectation.add( + String.format( + "data(%s) but frame type %n<%s> does not support data", actual, frameType)); + } else { + if (frameType.hasInitialRequestN()) { + dataContent = RequestStreamFrameCodec.data(actual); + } else { + dataContent = PayloadFrameCodec.data(actual); + } + + if (!ByteBufUtil.equals(expectedPayload.sliceData(), dataContent)) { + failedExpectation.add( + String.format( + "data(%s) but actual was data(%s)%n", expectedPayload.sliceData(), dataContent)); + } + } + + if (!failedExpectation.isEmpty()) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting be equal to the given payload but the following differences were found" + + " %s", + failedExpectation)); + } + + return this; + } + + public void hasNoLeaks() { + if (!actual.release() || actual.refCnt() > 0) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have refCnt(0) after release but " + + "actual was " + + "%n", + actual, actual.refCnt())); + } + } + + private void assertValid() { + Objects.instance().assertNotNull(info, actual); + + try { + FrameHeaderCodec.frameType(actual); + } catch (Throwable t) { + throw failures.failure( + info, shouldBe(actual, new Condition<>("a valid frame, but got exception [" + t + "]"))); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/FrameTest.java b/rsocket-core/src/test/java/io/rsocket/FrameTest.java index 7850914ce..82af5f53c 100644 --- a/rsocket-core/src/test/java/io/rsocket/FrameTest.java +++ b/rsocket-core/src/test/java/io/rsocket/FrameTest.java @@ -1,16 +1,27 @@ -package io.rsocket; - -import static org.junit.Assert.assertEquals; +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ -import io.rsocket.frame.FrameHeaderFlyweight; -import io.rsocket.util.PayloadImpl; -import org.junit.Test; +package io.rsocket; public class FrameTest { - @Test + /*@Test public void testFrameToString() { - final Frame requestFrame = - Frame.Request.from(1, FrameType.REQUEST_RESPONSE, new PayloadImpl("streaming in -> 0"), 1); + final io.rsocket.Frame requestFrame = + io.rsocket.Frame.Request.from( + 1, FrameType.REQUEST_RESPONSE, DefaultPayload.create("streaming in -> 0"), 1); assertEquals( "Frame => Stream ID: 1 Type: REQUEST_RESPONSE Payload: data: \"streaming in -> 0\" ", requestFrame.toString()); @@ -18,9 +29,12 @@ public void testFrameToString() { @Test public void testFrameWithMetadataToString() { - final Frame requestFrame = - Frame.Request.from( - 1, FrameType.REQUEST_RESPONSE, new PayloadImpl("streaming in -> 0", "metadata"), 1); + final io.rsocket.Frame requestFrame = + io.rsocket.Frame.Request.from( + 1, + FrameType.REQUEST_RESPONSE, + DefaultPayload.create("streaming in -> 0", "metadata"), + 1); assertEquals( "Frame => Stream ID: 1 Type: REQUEST_RESPONSE Payload: metadata: \"metadata\" data: \"streaming in -> 0\" ", requestFrame.toString()); @@ -28,9 +42,12 @@ public void testFrameWithMetadataToString() { @Test public void testPayload() { - Frame frame = - Frame.PayloadFrame.from( - 1, FrameType.NEXT_COMPLETE, new PayloadImpl("Hello"), FrameHeaderFlyweight.FLAGS_C); + io.rsocket.Frame frame = + io.rsocket.Frame.PayloadFrame.from( + 1, + FrameType.NEXT_COMPLETE, + DefaultPayload.create("Hello"), + FrameHeaderFlyweight.FLAGS_C); frame.toString(); - } + }*/ } diff --git a/rsocket-core/src/test/java/io/rsocket/PayloadAssert.java b/rsocket-core/src/test/java/io/rsocket/PayloadAssert.java new file mode 100755 index 000000000..847f24722 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/PayloadAssert.java @@ -0,0 +1,180 @@ +package io.rsocket; + +import static org.assertj.core.error.ShouldBeEqual.shouldBeEqual; +import static org.assertj.core.error.ShouldHave.shouldHave; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.rsocket.frame.ByteBufRepresentation; +import io.rsocket.util.DefaultPayload; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.List; +import org.assertj.core.api.AbstractAssert; +import org.assertj.core.api.Condition; +import org.assertj.core.error.BasicErrorMessageFactory; +import org.assertj.core.internal.Failures; +import org.assertj.core.internal.Objects; +import reactor.util.annotation.Nullable; + +public class PayloadAssert extends AbstractAssert { + + public static PayloadAssert assertThat(@Nullable Payload payload) { + return new PayloadAssert(payload); + } + + private final Failures failures = Failures.instance(); + + public PayloadAssert(@Nullable Payload payload) { + super(payload, PayloadAssert.class); + } + + public PayloadAssert hasMetadata() { + assertValid(); + + if (!actual.hasMetadata()) { + throw failures.failure(info, shouldHave(actual, new Condition<>("metadata present"))); + } + + return this; + } + + public PayloadAssert hasNoMetadata() { + assertValid(); + + if (actual.hasMetadata()) { + throw failures.failure(info, shouldHave(actual, new Condition<>("metadata absent"))); + } + + return this; + } + + public PayloadAssert hasMetadata(String metadata, Charset charset) { + return hasMetadata(metadata.getBytes(charset)); + } + + public PayloadAssert hasMetadata(String metadataUtf8) { + return hasMetadata(metadataUtf8, CharsetUtil.UTF_8); + } + + public PayloadAssert hasMetadata(byte[] metadata) { + return hasMetadata(Unpooled.wrappedBuffer(metadata)); + } + + public PayloadAssert hasMetadata(ByteBuf metadata) { + hasMetadata(); + + ByteBuf content = actual.sliceMetadata(); + if (!ByteBufUtil.equals(content, metadata)) { + throw failures.failure(info, shouldBeEqual(content, metadata, new ByteBufRepresentation())); + } + + return this; + } + + public PayloadAssert hasData(String dataUtf8) { + return hasData(dataUtf8, CharsetUtil.UTF_8); + } + + public PayloadAssert hasData(String data, Charset charset) { + return hasData(data.getBytes(charset)); + } + + public PayloadAssert hasData(byte[] data) { + return hasData(Unpooled.wrappedBuffer(data)); + } + + public PayloadAssert hasData(ByteBuf data) { + assertValid(); + + ByteBuf content = actual.sliceData(); + if (!ByteBufUtil.equals(content, data)) { + throw failures.failure(info, shouldBeEqual(content, data, new ByteBufRepresentation())); + } + + return this; + } + + public void hasNoLeaks() { + if (!(actual instanceof DefaultPayload)) { + if (actual.refCnt() == 0) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have refCnt(0) after release but " + + "actual was already released", + actual, actual.refCnt())); + } + if (!actual.release() || actual.refCnt() > 0) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have refCnt(0) after release but " + + "actual was " + + "%n", + actual, actual.refCnt())); + } + } + } + + public void isReleased() { + if (actual.refCnt() > 0) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have refCnt(0) but " + "actual was " + "%n", + actual, actual.refCnt())); + } + } + + @Override + public PayloadAssert isEqualTo(Object expected) { + if (expected instanceof Payload) { + if (expected == actual) { + return this; + } + + Payload expectedPayload = (Payload) expected; + List failedExpectation = new ArrayList<>(); + if (expectedPayload.hasMetadata() != actual.hasMetadata()) { + failedExpectation.add( + String.format( + "hasMetadata(%s) but actual was hasMetadata(%s)%n", + expectedPayload.hasMetadata(), actual.hasMetadata())); + } else { + if (!ByteBufUtil.equals(expectedPayload.sliceMetadata(), actual.sliceMetadata())) { + failedExpectation.add( + String.format( + "metadata(%s) but actual was metadata(%s)%n", + expectedPayload.sliceMetadata(), actual.sliceMetadata())); + } + } + + if (!ByteBufUtil.equals(expectedPayload.sliceData(), actual.sliceData())) { + failedExpectation.add( + String.format( + "data(%s) but actual was data(%s)%n", + expectedPayload.sliceData(), actual.sliceData())); + } + + if (!failedExpectation.isEmpty()) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting be equal to the given one but the following differences were found" + + " %s", + failedExpectation)); + } + + return this; + } + + return super.isEqualTo(expected); + } + + private void assertValid() { + Objects.instance().assertNotNull(info, actual); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketClientTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketClientTest.java deleted file mode 100644 index d5d824ef5..000000000 --- a/rsocket-core/src/test/java/io/rsocket/RSocketClientTest.java +++ /dev/null @@ -1,219 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import static io.rsocket.FrameType.*; -import static io.rsocket.test.util.TestSubscriber.anyPayload; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.contains; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.greaterThanOrEqualTo; -import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.not; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; - -import io.rsocket.exceptions.ApplicationException; -import io.rsocket.exceptions.RejectedSetupException; -import io.rsocket.frame.RequestFrameFlyweight; -import io.rsocket.test.util.TestSubscriber; -import io.rsocket.util.PayloadImpl; -import java.time.Duration; -import java.util.ArrayList; -import java.util.List; -import java.util.stream.Collectors; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import reactor.core.publisher.BaseSubscriber; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -public class RSocketClientTest { - - @Rule public final ClientSocketRule rule = new ClientSocketRule(); - - @Test(timeout = 2_000) - public void testKeepAlive() throws Exception { - assertThat("Unexpected frame sent.", rule.connection.awaitSend().getType(), is(KEEPALIVE)); - } - - @Test(timeout = 2_000) - public void testInvalidFrameOnStream0() { - rule.connection.addToReceivedBuffer(Frame.RequestN.from(0, 10)); - assertThat("Unexpected errors.", rule.errors, hasSize(1)); - assertThat( - "Unexpected error received.", - rule.errors, - contains(instanceOf(IllegalStateException.class))); - } - - @Test(timeout = 2_000) - public void testStreamInitialN() { - Flux stream = rule.socket.requestStream(PayloadImpl.EMPTY); - - BaseSubscriber subscriber = - new BaseSubscriber() { - @Override - protected void hookOnSubscribe(Subscription subscription) { - // don't request here - // subscription.request(3); - } - }; - stream.subscribe(subscriber); - - subscriber.request(5); - - List sent = - rule.connection - .getSent() - .stream() - .filter(f -> f.getType() != KEEPALIVE) - .collect(Collectors.toList()); - - assertThat("sent frame count", sent.size(), is(1)); - - Frame f = sent.get(0); - - assertThat("initial frame", f.getType(), is(REQUEST_STREAM)); - assertThat("initial request n", RequestFrameFlyweight.initialRequestN(f.content()), is(5)); - } - - @Test(timeout = 2_000) - public void testHandleSetupException() { - rule.connection.addToReceivedBuffer(Frame.Error.from(0, new RejectedSetupException("boom"))); - assertThat("Unexpected errors.", rule.errors, hasSize(1)); - assertThat( - "Unexpected error received.", - rule.errors, - contains(instanceOf(RejectedSetupException.class))); - } - - @Test(timeout = 2_000) - public void testHandleApplicationException() { - rule.connection.clearSendReceiveBuffers(); - Publisher response = rule.socket.requestResponse(PayloadImpl.EMPTY); - Subscriber responseSub = TestSubscriber.create(); - response.subscribe(responseSub); - - int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); - rule.connection.addToReceivedBuffer( - Frame.Error.from(streamId, new ApplicationException("error"))); - - verify(responseSub).onError(any(ApplicationException.class)); - } - - @Test(timeout = 2_000) - public void testHandleValidFrame() { - Publisher response = rule.socket.requestResponse(PayloadImpl.EMPTY); - Subscriber sub = TestSubscriber.create(); - response.subscribe(sub); - - int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); - rule.connection.addToReceivedBuffer( - Frame.PayloadFrame.from(streamId, NEXT_COMPLETE, PayloadImpl.EMPTY)); - - verify(sub).onNext(anyPayload()); - verify(sub).onComplete(); - } - - @Test(timeout = 2_000) - public void testRequestReplyWithCancel() { - Mono response = rule.socket.requestResponse(PayloadImpl.EMPTY); - - try { - response.block(Duration.ofMillis(100)); - } catch (IllegalStateException ise) { - } - - List sent = - rule.connection - .getSent() - .stream() - .filter(f -> f.getType() != KEEPALIVE) - .collect(Collectors.toList()); - - assertThat( - "Unexpected frame sent on the connection.", sent.get(0).getType(), is(REQUEST_RESPONSE)); - assertThat("Unexpected frame sent on the connection.", sent.get(1).getType(), is(CANCEL)); - } - - @Test(timeout = 2_000) - @Ignore - public void testRequestReplyErrorOnSend() { - rule.connection.setAvailability(0); // Fails send - Mono response = rule.socket.requestResponse(PayloadImpl.EMPTY); - Subscriber responseSub = TestSubscriber.create(); - response.subscribe(responseSub); - - verify(responseSub).onError(any(RuntimeException.class)); - } - - @Test - public void testLazyRequestResponse() { - Publisher response = rule.socket.requestResponse(PayloadImpl.EMPTY); - int streamId = sendRequestResponse(response); - rule.connection.clearSendReceiveBuffers(); - int streamId2 = sendRequestResponse(response); - assertThat("Stream ID reused.", streamId2, not(equalTo(streamId))); - } - - public int sendRequestResponse(Publisher response) { - Subscriber sub = TestSubscriber.create(); - response.subscribe(sub); - int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); - rule.connection.addToReceivedBuffer( - Frame.PayloadFrame.from(streamId, NEXT_COMPLETE, PayloadImpl.EMPTY)); - verify(sub).onNext(anyPayload()); - verify(sub).onComplete(); - return streamId; - } - - public static class ClientSocketRule extends AbstractSocketRule { - @Override - protected RSocketClient newRSocket() { - return new RSocketClient( - connection, - throwable -> errors.add(throwable), - StreamIdSupplier.clientSupplier(), - Duration.ofMillis(100), - Duration.ofMillis(100), - 4); - } - - public int getStreamIdForRequestType(FrameType expectedFrameType) { - assertThat("Unexpected frames sent.", connection.getSent(), hasSize(greaterThanOrEqualTo(1))); - List framesFound = new ArrayList<>(); - for (Frame frame : connection.getSent()) { - if (frame.getType() == expectedFrameType) { - return frame.getStreamId(); - } - framesFound.add(frame.getType()); - } - throw new AssertionError( - "No frames sent with frame type: " - + expectedFrameType - + ", frames found: " - + framesFound); - } - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketServerTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketServerTest.java deleted file mode 100644 index bd6e18792..000000000 --- a/rsocket-core/src/test/java/io/rsocket/RSocketServerTest.java +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.anyOf; -import static org.hamcrest.Matchers.empty; -import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.is; - -import io.netty.buffer.Unpooled; -import io.rsocket.test.util.TestDuplexConnection; -import io.rsocket.test.util.TestSubscriber; -import io.rsocket.util.PayloadImpl; -import java.util.Collection; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.atomic.AtomicBoolean; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.reactivestreams.Subscriber; -import reactor.core.publisher.Mono; - -public class RSocketServerTest { - - @Rule public final ServerSocketRule rule = new ServerSocketRule(); - - @Test(timeout = 2000) - @Ignore - public void testHandleKeepAlive() throws Exception { - rule.connection.addToReceivedBuffer(Frame.Keepalive.from(Unpooled.EMPTY_BUFFER, true)); - Frame sent = rule.connection.awaitSend(); - assertThat("Unexpected frame sent.", sent.getType(), is(FrameType.KEEPALIVE)); - /*Keep alive ack must not have respond flag else, it will result in infinite ping-pong of keep alive frames.*/ - assertThat( - "Unexpected keep-alive frame respond flag.", - Frame.Keepalive.hasRespondFlag(sent), - is(false)); - } - - @Test(timeout = 2000) - @Ignore - public void testHandleResponseFrameNoError() throws Exception { - final int streamId = 4; - rule.connection.clearSendReceiveBuffers(); - - rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE); - - Collection> sendSubscribers = rule.connection.getSendSubscribers(); - assertThat("Request not sent.", sendSubscribers, hasSize(1)); - assertThat("Unexpected error.", rule.errors, is(empty())); - Subscriber sendSub = sendSubscribers.iterator().next(); - assertThat( - "Unexpected frame sent.", - rule.connection.awaitSend().getType(), - anyOf(is(FrameType.COMPLETE), is(FrameType.NEXT_COMPLETE))); - } - - @Test(timeout = 2000) - @Ignore - public void testHandlerEmitsError() throws Exception { - final int streamId = 4; - rule.sendRequest(streamId, FrameType.REQUEST_STREAM); - assertThat("Unexpected error.", rule.errors, is(empty())); - assertThat( - "Unexpected frame sent.", rule.connection.awaitSend().getType(), is(FrameType.ERROR)); - } - - @Test(timeout = 2_0000) - public void testCancel() { - final int streamId = 4; - final AtomicBoolean cancelled = new AtomicBoolean(); - rule.setAcceptingSocket( - new AbstractRSocket() { - @Override - public Mono requestResponse(Payload payload) { - return Mono.never().doOnCancel(() -> cancelled.set(true)); - } - }); - rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE); - - assertThat("Unexpected error.", rule.errors, is(empty())); - assertThat("Unexpected frame sent.", rule.connection.getSent(), is(empty())); - - rule.connection.addToReceivedBuffer(Frame.Cancel.from(streamId)); - assertThat("Unexpected frame sent.", rule.connection.getSent(), is(empty())); - assertThat("Subscription not cancelled.", cancelled.get(), is(true)); - } - - public static class ServerSocketRule extends AbstractSocketRule { - - private RSocket acceptingSocket; - - @Override - protected void init() { - acceptingSocket = - new AbstractRSocket() { - @Override - public Mono requestResponse(Payload payload) { - return Mono.just(payload); - } - }; - super.init(); - } - - public void setAcceptingSocket(RSocket acceptingSocket) { - this.acceptingSocket = acceptingSocket; - connection = new TestDuplexConnection(); - connectSub = TestSubscriber.create(); - errors = new ConcurrentLinkedQueue<>(); - super.init(); - } - - @Override - protected RSocketServer newRSocket() { - return new RSocketServer(connection, acceptingSocket, throwable -> errors.add(throwable)); - } - - private void sendRequest(int streamId, FrameType frameType) { - Frame request = Frame.Request.from(streamId, frameType, PayloadImpl.EMPTY, 1); - connection.addToReceivedBuffer(request); - connection.addToReceivedBuffer(Frame.RequestN.from(streamId, 2)); - } - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/RSocketTest.java b/rsocket-core/src/test/java/io/rsocket/RSocketTest.java deleted file mode 100644 index 021f75829..000000000 --- a/rsocket-core/src/test/java/io/rsocket/RSocketTest.java +++ /dev/null @@ -1,158 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import static org.hamcrest.Matchers.empty; -import static org.hamcrest.Matchers.is; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; - -import io.rsocket.exceptions.ApplicationException; -import io.rsocket.test.util.LocalDuplexConnection; -import io.rsocket.test.util.TestSubscriber; -import io.rsocket.util.PayloadImpl; -import java.util.ArrayList; -import java.util.concurrent.CountDownLatch; -import org.hamcrest.MatcherAssert; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExternalResource; -import org.junit.runner.Description; -import org.junit.runners.model.Statement; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import reactor.core.publisher.DirectProcessor; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -public class RSocketTest { - - @Rule public final SocketRule rule = new SocketRule(); - - @Test(timeout = 2_000) - public void testRequestReplyNoError() { - Subscriber subscriber = TestSubscriber.create(); - rule.crs.requestResponse(new PayloadImpl("hello")).subscribe(subscriber); - verify(subscriber).onNext(TestSubscriber.anyPayload()); - verify(subscriber).onComplete(); - rule.assertNoErrors(); - } - - @Test(timeout = 2000) - @Ignore - public void testHandlerEmitsError() { - rule.setRequestAcceptor( - new AbstractRSocket() { - @Override - public Mono requestResponse(Payload payload) { - return Mono.error(new NullPointerException("Deliberate exception.")); - } - }); - Subscriber subscriber = TestSubscriber.create(); - rule.crs.requestResponse(PayloadImpl.EMPTY).subscribe(subscriber); - verify(subscriber).onError(any(ApplicationException.class)); - rule.assertNoErrors(); - } - - @Test(timeout = 2000) - public void testChannel() throws Exception { - CountDownLatch latch = new CountDownLatch(10); - Flux requests = Flux.range(0, 10).map(i -> new PayloadImpl("streaming in -> " + i)); - - Flux responses = rule.crs.requestChannel(requests); - - responses.doOnNext(p -> latch.countDown()).subscribe(); - - latch.await(); - } - - public static class SocketRule extends ExternalResource { - - private RSocketClient crs; - private RSocketServer srs; - private RSocket requestAcceptor; - DirectProcessor serverProcessor; - DirectProcessor clientProcessor; - private ArrayList clientErrors = new ArrayList<>(); - private ArrayList serverErrors = new ArrayList<>(); - - @Override - public Statement apply(Statement base, Description description) { - return new Statement() { - @Override - public void evaluate() throws Throwable { - init(); - base.evaluate(); - } - }; - } - - protected void init() { - serverProcessor = DirectProcessor.create(); - clientProcessor = DirectProcessor.create(); - - LocalDuplexConnection serverConnection = - new LocalDuplexConnection("server", clientProcessor, serverProcessor); - LocalDuplexConnection clientConnection = - new LocalDuplexConnection("client", serverProcessor, clientProcessor); - - requestAcceptor = - null != requestAcceptor - ? requestAcceptor - : new AbstractRSocket() { - @Override - public Mono requestResponse(Payload payload) { - return Mono.just(payload); - } - - @Override - public Flux requestChannel(Publisher payloads) { - Flux.from(payloads) - .map(payload -> new PayloadImpl("server got -> [" + payload.toString() + "]")) - .subscribe(); - - return Flux.range(1, 10) - .map( - payload -> new PayloadImpl("server got -> [" + payload.toString() + "]")); - } - }; - - srs = - new RSocketServer( - serverConnection, requestAcceptor, throwable -> serverErrors.add(throwable)); - - crs = - new RSocketClient( - clientConnection, - throwable -> clientErrors.add(throwable), - StreamIdSupplier.clientSupplier()); - } - - public void setRequestAcceptor(RSocket requestAcceptor) { - this.requestAcceptor = requestAcceptor; - init(); - } - - public void assertNoErrors() { - MatcherAssert.assertThat( - "Unexpected error on the client connection.", clientErrors, is(empty())); - MatcherAssert.assertThat( - "Unexpected error on the server connection.", serverErrors, is(empty())); - } - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/RaceTestConstants.java b/rsocket-core/src/test/java/io/rsocket/RaceTestConstants.java new file mode 100644 index 000000000..d30f1415e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/RaceTestConstants.java @@ -0,0 +1,6 @@ +package io.rsocket; + +public class RaceTestConstants { + public static final int REPEATS = + Integer.parseInt(System.getProperty("rsocket.test.race.repeats", "1000")); +} diff --git a/rsocket-core/src/test/java/io/rsocket/StreamIdSupplierTest.java b/rsocket-core/src/test/java/io/rsocket/StreamIdSupplierTest.java deleted file mode 100644 index 3025b78eb..000000000 --- a/rsocket-core/src/test/java/io/rsocket/StreamIdSupplierTest.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - -import org.junit.Test; - -public class StreamIdSupplierTest { - @Test - public void testClientSequence() { - StreamIdSupplier s = StreamIdSupplier.clientSupplier(); - assertEquals(1, s.nextStreamId()); - assertEquals(3, s.nextStreamId()); - assertEquals(5, s.nextStreamId()); - } - - @Test - public void testServerSequence() { - StreamIdSupplier s = StreamIdSupplier.serverSupplier(); - assertEquals(2, s.nextStreamId()); - assertEquals(4, s.nextStreamId()); - assertEquals(6, s.nextStreamId()); - } - - @Test - public void testClientIsValid() { - StreamIdSupplier s = StreamIdSupplier.clientSupplier(); - - assertFalse(s.isBeforeOrCurrent(1)); - assertFalse(s.isBeforeOrCurrent(3)); - - s.nextStreamId(); - assertTrue(s.isBeforeOrCurrent(1)); - assertFalse(s.isBeforeOrCurrent(3)); - - s.nextStreamId(); - assertTrue(s.isBeforeOrCurrent(3)); - - // negative - assertFalse(s.isBeforeOrCurrent(-1)); - // connection - assertFalse(s.isBeforeOrCurrent(0)); - // server also accepted (checked externally) - assertTrue(s.isBeforeOrCurrent(2)); - } - - @Test - public void testServerIsValid() { - StreamIdSupplier s = StreamIdSupplier.serverSupplier(); - - assertFalse(s.isBeforeOrCurrent(2)); - assertFalse(s.isBeforeOrCurrent(4)); - - s.nextStreamId(); - assertTrue(s.isBeforeOrCurrent(2)); - assertFalse(s.isBeforeOrCurrent(4)); - - s.nextStreamId(); - assertTrue(s.isBeforeOrCurrent(4)); - - // negative - assertFalse(s.isBeforeOrCurrent(-2)); - // connection - assertFalse(s.isBeforeOrCurrent(0)); - // client also accepted (checked externally) - assertTrue(s.isBeforeOrCurrent(1)); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/TestScheduler.java b/rsocket-core/src/test/java/io/rsocket/TestScheduler.java new file mode 100644 index 000000000..7bc98d45d --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/TestScheduler.java @@ -0,0 +1,80 @@ +package io.rsocket; + +import java.util.Queue; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.Exceptions; +import reactor.core.scheduler.Scheduler; +import reactor.util.concurrent.Queues; + +/** + * This is an implementation of scheduler which allows task execution on the caller thread or + * scheduling it for thread which are currently working (with "work stealing" behaviour) + */ +public final class TestScheduler implements Scheduler { + + public static final Scheduler INSTANCE = new TestScheduler(); + + volatile int wip; + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(TestScheduler.class, "wip"); + + final Worker sharedWorker = new TestWorker(this); + final Queue tasks = Queues.unboundedMultiproducer().get(); + + private TestScheduler() {} + + @Override + public Disposable schedule(Runnable task) { + tasks.offer(task); + if (WIP.getAndIncrement(this) != 0) { + return Disposables.never(); + } + + int missed = 1; + + for (; ; ) { + for (; ; ) { + Runnable runnable = tasks.poll(); + + if (runnable == null) { + break; + } + + try { + runnable.run(); + } catch (Throwable t) { + Exceptions.throwIfFatal(t); + } + } + + missed = WIP.addAndGet(this, -missed); + if (missed == 0) { + return Disposables.never(); + } + } + } + + @Override + public Worker createWorker() { + return sharedWorker; + } + + static class TestWorker implements Worker { + + final TestScheduler parent; + + TestWorker(TestScheduler parent) { + this.parent = parent; + } + + @Override + public Disposable schedule(Runnable task) { + return parent.schedule(task); + } + + @Override + public void dispose() {} + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java b/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java new file mode 100644 index 000000000..1db708ab5 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java @@ -0,0 +1,294 @@ +package io.rsocket.buffer; + +import static java.util.concurrent.locks.LockSupport.parkNanos; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ResourceLeakDetector; +import java.lang.reflect.Field; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; +import org.assertj.core.api.Assertions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Additional Utils which allows to decorate a ByteBufAllocator and track/assertOnLeaks all created + * ByteBuffs + */ +public class LeaksTrackingByteBufAllocator implements ByteBufAllocator { + static final Logger LOGGER = LoggerFactory.getLogger(LeaksTrackingByteBufAllocator.class); + + /** + * Allows to instrument any given the instance of ByteBufAllocator + * + * @param allocator + * @return + */ + public static LeaksTrackingByteBufAllocator instrument(ByteBufAllocator allocator) { + return new LeaksTrackingByteBufAllocator(allocator, Duration.ZERO, ""); + } + + /** + * Allows to instrument any given the instance of ByteBufAllocator + * + * @param allocator + * @return + */ + public static LeaksTrackingByteBufAllocator instrument( + ByteBufAllocator allocator, Duration awaitZeroRefCntDuration, String tag) { + return new LeaksTrackingByteBufAllocator(allocator, awaitZeroRefCntDuration, tag); + } + + final ConcurrentLinkedQueue tracker = new ConcurrentLinkedQueue<>(); + + final ByteBufAllocator delegate; + + final Duration awaitZeroRefCntDuration; + + final String tag; + + private LeaksTrackingByteBufAllocator( + ByteBufAllocator delegate, Duration awaitZeroRefCntDuration, String tag) { + this.delegate = delegate; + this.awaitZeroRefCntDuration = awaitZeroRefCntDuration; + this.tag = tag; + } + + public LeaksTrackingByteBufAllocator assertHasNoLeaks() { + try { + ArrayList unreleased = new ArrayList<>(); + for (ByteBuf bb : tracker) { + if (bb.refCnt() != 0) { + unreleased.add(bb); + } + } + + final Duration awaitZeroRefCntDuration = this.awaitZeroRefCntDuration; + if (!unreleased.isEmpty() && !awaitZeroRefCntDuration.isZero()) { + final long startTime = System.currentTimeMillis(); + final long endTimeInMillis = startTime + awaitZeroRefCntDuration.toMillis(); + boolean hasUnreleased; + while (System.currentTimeMillis() <= endTimeInMillis) { + hasUnreleased = false; + for (ByteBuf bb : unreleased) { + if (bb.refCnt() != 0) { + hasUnreleased = true; + break; + } + } + + if (!hasUnreleased) { + return this; + } + + LOGGER.debug(tag + " await buffers to be released"); + for (int i = 0; i < 100; i++) { + System.gc(); + parkNanos(1000); + System.gc(); + } + } + } + + Set collected = new HashSet<>(); + for (ByteBuf buf : unreleased) { + if (buf.refCnt() != 0) { + try { + collected.add(buf); + } catch (IllegalReferenceCountException ignored) { + // fine to ignore if throws because of refCnt + } + } + } + + Assertions.assertThat( + collected + .stream() + .filter(bb -> bb.refCnt() != 0) + .peek( + bb -> { + try { + LOGGER.debug(tag + " " + resolveTrackingInfo(bb)); + } catch (Exception e) { + e.printStackTrace(); + } + })) + .describedAs("[" + tag + "] all buffers expected to be released but got ") + .isEmpty(); + } finally { + tracker.clear(); + } + return this; + } + + // Delegating logic with tracking of buffers + + @Override + public ByteBuf buffer() { + return track(delegate.buffer()); + } + + @Override + public ByteBuf buffer(int initialCapacity) { + return track(delegate.buffer(initialCapacity)); + } + + @Override + public ByteBuf buffer(int initialCapacity, int maxCapacity) { + return track(delegate.buffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf ioBuffer() { + return track(delegate.ioBuffer()); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity) { + return track(delegate.ioBuffer(initialCapacity)); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.ioBuffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf heapBuffer() { + return track(delegate.heapBuffer()); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity) { + return track(delegate.heapBuffer(initialCapacity)); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.heapBuffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf directBuffer() { + return track(delegate.directBuffer()); + } + + @Override + public ByteBuf directBuffer(int initialCapacity) { + return track(delegate.directBuffer(initialCapacity)); + } + + @Override + public ByteBuf directBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.directBuffer(initialCapacity, maxCapacity)); + } + + @Override + public CompositeByteBuf compositeBuffer() { + return track(delegate.compositeBuffer()); + } + + @Override + public CompositeByteBuf compositeBuffer(int maxNumComponents) { + return track(delegate.compositeBuffer(maxNumComponents)); + } + + @Override + public CompositeByteBuf compositeHeapBuffer() { + return track(delegate.compositeHeapBuffer()); + } + + @Override + public CompositeByteBuf compositeHeapBuffer(int maxNumComponents) { + return track(delegate.compositeHeapBuffer(maxNumComponents)); + } + + @Override + public CompositeByteBuf compositeDirectBuffer() { + return track(delegate.compositeDirectBuffer()); + } + + @Override + public CompositeByteBuf compositeDirectBuffer(int maxNumComponents) { + return track(delegate.compositeDirectBuffer(maxNumComponents)); + } + + @Override + public boolean isDirectBufferPooled() { + return delegate.isDirectBufferPooled(); + } + + @Override + public int calculateNewCapacity(int minNewCapacity, int maxCapacity) { + return delegate.calculateNewCapacity(minNewCapacity, maxCapacity); + } + + T track(T buffer) { + tracker.offer(buffer); + + return buffer; + } + + static final Class simpleLeakAwareCompositeByteBufClass; + static final Field leakFieldForComposite; + static final Class simpleLeakAwareByteBufClass; + static final Field leakFieldForNormal; + static final Field allLeaksField; + + static { + try { + { + final Class aClass = Class.forName("io.netty.buffer.SimpleLeakAwareCompositeByteBuf"); + final Field leakField = aClass.getDeclaredField("leak"); + + leakField.setAccessible(true); + + simpleLeakAwareCompositeByteBufClass = aClass; + leakFieldForComposite = leakField; + } + + { + final Class aClass = Class.forName("io.netty.buffer.SimpleLeakAwareByteBuf"); + final Field leakField = aClass.getDeclaredField("leak"); + + leakField.setAccessible(true); + + simpleLeakAwareByteBufClass = aClass; + leakFieldForNormal = leakField; + } + + { + final Class aClass = + Class.forName("io.netty.util.ResourceLeakDetector$DefaultResourceLeak"); + final Field field = aClass.getDeclaredField("allLeaks"); + + field.setAccessible(true); + + allLeaksField = field; + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @SuppressWarnings("unchecked") + static Set resolveTrackingInfo(ByteBuf byteBuf) throws Exception { + if (ResourceLeakDetector.getLevel().ordinal() + >= ResourceLeakDetector.Level.ADVANCED.ordinal()) { + if (simpleLeakAwareCompositeByteBufClass.isInstance(byteBuf)) { + return (Set) allLeaksField.get(leakFieldForComposite.get(byteBuf)); + } else if (simpleLeakAwareByteBufClass.isInstance(byteBuf)) { + return (Set) allLeaksField.get(leakFieldForNormal.get(byteBuf)); + } + } + + return Collections.emptySet(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java b/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java new file mode 100644 index 000000000..310e15b3e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java @@ -0,0 +1,76 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.test.util.TestSubscriber; +import java.time.Duration; +import org.reactivestreams.Subscriber; + +public abstract class AbstractSocketRule { + + protected TestDuplexConnection connection; + protected Subscriber connectSub; + protected T socket; + protected LeaksTrackingByteBufAllocator allocator; + protected int maxFrameLength = FRAME_LENGTH_MASK; + protected int maxInboundPayloadSize = Integer.MAX_VALUE; + + public void init() { + allocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(5), ""); + connectSub = TestSubscriber.create(); + doInit(); + } + + protected void doInit() { + if (connection != null) { + connection.dispose(); + } + if (socket != null) { + socket.dispose(); + } + connection = new TestDuplexConnection(allocator); + socket = newRSocket(); + } + + public void setMaxInboundPayloadSize(int maxInboundPayloadSize) { + this.maxInboundPayloadSize = maxInboundPayloadSize; + doInit(); + } + + public void setMaxFrameLength(int maxFrameLength) { + this.maxFrameLength = maxFrameLength; + doInit(); + } + + protected abstract T newRSocket(); + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } + + public void assertHasNoLeaks() { + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/ClientServerInputMultiplexerTest.java b/rsocket-core/src/test/java/io/rsocket/core/ClientServerInputMultiplexerTest.java new file mode 100644 index 000000000..195df9434 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ClientServerInputMultiplexerTest.java @@ -0,0 +1,172 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.frame.MetadataPushFrameCodec; +import io.rsocket.plugins.InitializingInterceptorRegistry; +import io.rsocket.test.util.TestDuplexConnection; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class ClientServerInputMultiplexerTest { + private TestDuplexConnection source; + private ClientServerInputMultiplexer clientMultiplexer; + private LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + private ClientServerInputMultiplexer serverMultiplexer; + + @BeforeEach + public void setup() { + source = new TestDuplexConnection(allocator); + clientMultiplexer = + new ClientServerInputMultiplexer(source, new InitializingInterceptorRegistry(), true); + serverMultiplexer = + new ClientServerInputMultiplexer(source, new InitializingInterceptorRegistry(), false); + } + + @Test + public void clientSplits() { + AtomicInteger clientFrames = new AtomicInteger(); + AtomicInteger serverFrames = new AtomicInteger(); + + clientMultiplexer + .asClientConnection() + .receive() + .doOnNext( + f -> { + clientFrames.incrementAndGet(); + f.release(); + }) + .subscribe(); + clientMultiplexer + .asServerConnection() + .receive() + .doOnNext( + f -> { + serverFrames.incrementAndGet(); + f.release(); + }) + .subscribe(); + + source.addToReceivedBuffer(errorFrame(1).retain()); + assertThat(clientFrames.get()).isOne(); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(errorFrame(1).retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(leaseFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(3); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(keepAliveFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(4); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(errorFrame(2).retain()); + assertThat(clientFrames.get()).isEqualTo(4); + assertThat(serverFrames.get()).isOne(); + + source.addToReceivedBuffer(errorFrame(0).retain()); + assertThat(clientFrames.get()).isEqualTo(5); + assertThat(serverFrames.get()).isOne(); + + source.addToReceivedBuffer(metadataPushFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(5); + assertThat(serverFrames.get()).isEqualTo(2); + } + + @Test + public void serverSplits() { + AtomicInteger clientFrames = new AtomicInteger(); + AtomicInteger serverFrames = new AtomicInteger(); + + serverMultiplexer + .asClientConnection() + .receive() + .doOnNext( + f -> { + clientFrames.incrementAndGet(); + f.release(); + }) + .subscribe(); + serverMultiplexer + .asServerConnection() + .receive() + .doOnNext( + f -> { + serverFrames.incrementAndGet(); + f.release(); + }) + .subscribe(); + + source.addToReceivedBuffer(errorFrame(1).retain()); + assertThat(clientFrames.get()).isEqualTo(1); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(errorFrame(1).retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(leaseFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isOne(); + + source.addToReceivedBuffer(keepAliveFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isEqualTo(2); + + source.addToReceivedBuffer(errorFrame(2).retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isEqualTo(3); + + source.addToReceivedBuffer(errorFrame(0).retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isEqualTo(4); + + source.addToReceivedBuffer(metadataPushFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(3); + assertThat(serverFrames.get()).isEqualTo(4); + } + + private ByteBuf leaseFrame() { + return LeaseFrameCodec.encode(allocator, 1_000, 1, Unpooled.EMPTY_BUFFER); + } + + private ByteBuf errorFrame(int i) { + return ErrorFrameCodec.encode(allocator, i, new Exception()); + } + + private ByteBuf keepAliveFrame() { + return KeepAliveFrameCodec.encode(allocator, false, 0, Unpooled.EMPTY_BUFFER); + } + + private ByteBuf metadataPushFrame() { + return MetadataPushFrameCodec.encode(allocator, Unpooled.EMPTY_BUFFER); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/ConnectionSetupPayloadTest.java b/rsocket-core/src/test/java/io/rsocket/core/ConnectionSetupPayloadTest.java new file mode 100644 index 000000000..8eb5dee09 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ConnectionSetupPayloadTest.java @@ -0,0 +1,90 @@ +package io.rsocket.core; + +import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.Payload; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.util.DefaultPayload; +import org.junit.jupiter.api.Test; + +class ConnectionSetupPayloadTest { + private static final int KEEP_ALIVE_INTERVAL = 5; + private static final int KEEP_ALIVE_MAX_LIFETIME = 500; + private static final String METADATA_TYPE = "metadata_type"; + private static final String DATA_TYPE = "data_type"; + + @Test + void testSetupPayloadWithDataMetadata() { + ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); + ByteBuf metadata = Unpooled.wrappedBuffer(new byte[] {2, 1, 0}); + Payload payload = DefaultPayload.create(data, metadata); + boolean leaseEnabled = true; + + ByteBuf frame = encodeSetupFrame(leaseEnabled, payload); + ConnectionSetupPayload setupPayload = new DefaultConnectionSetupPayload(frame); + + assertTrue(setupPayload.willClientHonorLease()); + assertEquals(KEEP_ALIVE_INTERVAL, setupPayload.keepAliveInterval()); + assertEquals(KEEP_ALIVE_MAX_LIFETIME, setupPayload.keepAliveMaxLifetime()); + assertEquals(METADATA_TYPE, SetupFrameCodec.metadataMimeType(frame)); + assertEquals(DATA_TYPE, SetupFrameCodec.dataMimeType(frame)); + assertTrue(setupPayload.hasMetadata()); + assertNotNull(setupPayload.metadata()); + assertEquals(payload.metadata(), setupPayload.metadata()); + assertEquals(payload.data(), setupPayload.data()); + frame.release(); + } + + @Test + void testSetupPayloadWithNoMetadata() { + ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); + ByteBuf metadata = null; + Payload payload = DefaultPayload.create(data, metadata); + boolean leaseEnabled = false; + + ByteBuf frame = encodeSetupFrame(leaseEnabled, payload); + ConnectionSetupPayload setupPayload = new DefaultConnectionSetupPayload(frame); + + assertFalse(setupPayload.willClientHonorLease()); + assertFalse(setupPayload.hasMetadata()); + assertNotNull(setupPayload.metadata()); + assertEquals(0, setupPayload.metadata().readableBytes()); + assertEquals(payload.data(), setupPayload.data()); + frame.release(); + } + + @Test + void testSetupPayloadWithEmptyMetadata() { + ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); + ByteBuf metadata = Unpooled.EMPTY_BUFFER; + Payload payload = DefaultPayload.create(data, metadata); + boolean leaseEnabled = false; + + ByteBuf frame = encodeSetupFrame(leaseEnabled, payload); + ConnectionSetupPayload setupPayload = new DefaultConnectionSetupPayload(frame); + + assertFalse(setupPayload.willClientHonorLease()); + assertTrue(setupPayload.hasMetadata()); + assertNotNull(setupPayload.metadata()); + assertEquals(0, setupPayload.metadata().readableBytes()); + assertEquals(payload.data(), setupPayload.data()); + frame.release(); + } + + private static ByteBuf encodeSetupFrame(boolean leaseEnabled, Payload setupPayload) { + return SetupFrameCodec.encode( + ByteBufAllocator.DEFAULT, + leaseEnabled, + KEEP_ALIVE_INTERVAL, + KEEP_ALIVE_MAX_LIFETIME, + Unpooled.EMPTY_BUFFER, + METADATA_TYPE, + DATA_TYPE, + setupPayload); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java b/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java new file mode 100644 index 000000000..84576e6ce --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java @@ -0,0 +1,760 @@ +package io.rsocket.core; +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import io.netty.buffer.ByteBuf; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.RaceTestConstants; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.RSocketProxy; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Map; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.assertj.core.api.Assumptions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; +import org.reactivestreams.Publisher; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.SignalType; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; +import reactor.util.context.Context; +import reactor.util.context.ContextView; +import reactor.util.retry.Retry; + +public class DefaultRSocketClientTests { + + ClientSocketRule rule; + + @BeforeEach + public void setUp() throws Throwable { + Hooks.onNextDropped(ReferenceCountUtil::safeRelease); + Hooks.onErrorDropped((t) -> {}); + rule = new ClientSocketRule(); + rule.init(); + } + + @AfterEach + public void tearDown() { + Hooks.resetOnErrorDropped(); + Hooks.resetOnNextDropped(); + rule.allocator.assertHasNoLeaks(); + } + + @Test + @SuppressWarnings("unchecked") + void discardElementsConsumerShouldAcceptOtherTypesThanReferenceCounted() { + Consumer discardElementsConsumer = DefaultRSocketClient.DISCARD_ELEMENTS_CONSUMER; + discardElementsConsumer.accept(new Object()); + } + + @Test + void droppedElementsConsumerReleaseReference() { + ReferenceCounted referenceCounted = Mockito.mock(ReferenceCounted.class); + Mockito.when(referenceCounted.release()).thenReturn(true); + Mockito.when(referenceCounted.refCnt()).thenReturn(1); + + Consumer discardElementsConsumer = DefaultRSocketClient.DISCARD_ELEMENTS_CONSUMER; + discardElementsConsumer.accept(referenceCounted); + + Mockito.verify(referenceCounted).release(); + } + + static Stream interactions() { + return Stream.of( + Arguments.of( + (BiFunction, Publisher>) + (client, payload) -> client.fireAndForget(Mono.fromDirect(payload)), + FrameType.REQUEST_FNF), + Arguments.of( + (BiFunction, Publisher>) + (client, payload) -> client.requestResponse(Mono.fromDirect(payload)), + FrameType.REQUEST_RESPONSE), + Arguments.of( + (BiFunction, Publisher>) + (client, payload) -> client.requestStream(Mono.fromDirect(payload)), + FrameType.REQUEST_STREAM), + Arguments.of( + (BiFunction, Publisher>) + RSocketClient::requestChannel, + FrameType.REQUEST_CHANNEL), + Arguments.of( + (BiFunction, Publisher>) + (client, payload) -> client.metadataPush(Mono.fromDirect(payload)), + FrameType.METADATA_PUSH)); + } + + @ParameterizedTest + @MethodSource("interactions") + public void shouldSentFrameOnResolution( + BiFunction, Publisher> request, FrameType requestType) { + Payload payload = ByteBufPayload.create("test", "testMetadata"); + TestPublisher testPublisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.DEFER_CANCELLATION); + + Publisher publisher = request.apply(rule.client, testPublisher); + + StepVerifier.create(publisher) + .expectSubscription() + .then(() -> Assertions.assertThat(rule.connection.getSent()).isEmpty()) + .then( + () -> { + if (requestType != FrameType.REQUEST_CHANNEL) { + testPublisher.next(payload); + } + }) + .then(() -> rule.delayer.run()) + .then( + () -> { + if (requestType == FrameType.REQUEST_CHANNEL) { + testPublisher.next(payload); + } + }) + .then(testPublisher::complete) + .then( + () -> { + if (requestType == FrameType.REQUEST_CHANNEL) { + Assertions.assertThat(rule.connection.getSent()) + .hasSize(2) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .matches(ReferenceCounted::release); + + Assertions.assertThat(rule.connection.getSent()) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(FrameType.COMPLETE)) + .matches(ReferenceCounted::release); + } else { + Assertions.assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .matches(ReferenceCounted::release); + } + }) + .then( + () -> { + if (requestType != FrameType.REQUEST_FNF && requestType != FrameType.METADATA_PUSH) { + rule.connection.addToReceivedBuffer( + PayloadFrameCodec.encodeComplete(rule.allocator, 1)); + } + }) + .expectComplete() + .verify(Duration.ofMillis(1000)); + + rule.allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("interactions") + @SuppressWarnings({"unchecked", "rawtypes"}) + public void shouldHaveNoLeaksOnPayloadInCaseOfRacingOfOnNextAndCancel( + BiFunction, Publisher> request, FrameType requestType) { + Assumptions.assumeThat(requestType).isNotEqualTo(FrameType.REQUEST_CHANNEL); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + ClientSocketRule rule = new ClientSocketRule(); + rule.init(); + Payload payload = ByteBufPayload.create("test", "testMetadata"); + TestPublisher testPublisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.DEFER_CANCELLATION); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + Publisher publisher = request.apply(rule.client, testPublisher); + publisher.subscribe(assertSubscriber); + + testPublisher.assertWasNotRequested(); + + assertSubscriber.request(1); + + testPublisher.assertWasRequested(); + testPublisher.assertMaxRequested(1); + testPublisher.assertMinRequested(1); + + RaceTestUtils.race( + () -> { + testPublisher.next(payload); + rule.delayer.run(); + }, + assertSubscriber::cancel); + + Collection sent = rule.connection.getSent(); + if (sent.size() == 1) { + Assertions.assertThat(sent) + .allMatch(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .allMatch(ReferenceCounted::release); + } else if (sent.size() == 2) { + Assertions.assertThat(sent) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .matches(ReferenceCounted::release); + Assertions.assertThat(sent) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(FrameType.CANCEL)) + .matches(ReferenceCounted::release); + } else { + Assertions.assertThat(sent).isEmpty(); + } + + rule.allocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest + @MethodSource("interactions") + @SuppressWarnings({"unchecked", "rawtypes"}) + public void shouldHaveNoLeaksOnPayloadInCaseOfRacingOfRequestAndCancel( + BiFunction, Publisher> request, FrameType requestType) { + Assumptions.assumeThat(requestType).isNotEqualTo(FrameType.REQUEST_CHANNEL); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + ClientSocketRule rule = new ClientSocketRule(); + rule.init(); + ByteBuf dataBuffer = rule.allocator.buffer(); + dataBuffer.writeCharSequence("test", CharsetUtil.UTF_8); + + ByteBuf metadataBuffer = rule.allocator.buffer(); + metadataBuffer.writeCharSequence("testMetadata", CharsetUtil.UTF_8); + + Payload payload = ByteBufPayload.create(dataBuffer, metadataBuffer); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + Publisher publisher = request.apply(rule.client, Mono.just(payload)); + publisher.subscribe(assertSubscriber); + + RaceTestUtils.race( + () -> { + assertSubscriber.request(1); + rule.delayer.run(); + }, + assertSubscriber::cancel); + + Collection sent = rule.connection.getSent(); + if (sent.size() == 1) { + Assertions.assertThat(sent) + .allMatch(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .allMatch(ReferenceCounted::release); + } else if (sent.size() == 2) { + Assertions.assertThat(sent) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .matches(ReferenceCounted::release); + Assertions.assertThat(sent) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(FrameType.CANCEL)) + .matches(ReferenceCounted::release); + } else { + Assertions.assertThat(sent).isEmpty(); + } + + rule.allocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest + @MethodSource("interactions") + @SuppressWarnings({"unchecked", "rawtypes"}) + public void shouldPropagateDownstreamContext( + BiFunction, Publisher> request, FrameType requestType) { + Assumptions.assumeThat(requestType).isNotEqualTo(FrameType.REQUEST_CHANNEL); + + ByteBuf dataBuffer = rule.allocator.buffer(); + dataBuffer.writeCharSequence("test", CharsetUtil.UTF_8); + + ByteBuf metadataBuffer = rule.allocator.buffer(); + metadataBuffer.writeCharSequence("testMetadata", CharsetUtil.UTF_8); + + Payload payload = ByteBufPayload.create(dataBuffer, metadataBuffer); + AssertSubscriber assertSubscriber = new AssertSubscriber(Context.of("test", "test")); + + ContextView[] receivedContext = new Context[1]; + Publisher publisher = + request.apply( + rule.client, + Mono.just(payload) + .mergeWith( + Mono.deferContextual( + c -> { + receivedContext[0] = c; + return Mono.empty(); + }) + .then(Mono.empty()))); + publisher.subscribe(assertSubscriber); + + rule.delayer.run(); + + Collection sent = rule.connection.getSent(); + if (sent.size() == 1) { + Assertions.assertThat(sent) + .allMatch(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .allMatch(ReferenceCounted::release); + } else if (sent.size() == 2) { + Assertions.assertThat(sent) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .matches(ReferenceCounted::release); + Assertions.assertThat(sent) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(FrameType.CANCEL)) + .matches(ReferenceCounted::release); + } else { + Assertions.assertThat(sent).isEmpty(); + } + + Assertions.assertThat(receivedContext) + .hasSize(1) + .allSatisfy( + c -> + Assertions.assertThat( + c.stream() + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))) + .containsKeys("test", DefaultRSocketClient.ON_DISCARD_KEY)); + + rule.allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("interactions") + @SuppressWarnings({"unchecked", "rawtypes"}) + public void shouldSupportMultiSubscriptionOnTheSameInteractionPublisher( + BiFunction, Publisher> request, FrameType requestType) { + AtomicBoolean once1 = new AtomicBoolean(); + AtomicBoolean once2 = new AtomicBoolean(); + Mono source = + Mono.fromCallable( + () -> { + if (!once1.getAndSet(true)) { + throw new IllegalStateException("test"); + } + return ByteBufPayload.create("test", "testMetadata"); + }) + .doFinally( + st -> { + rule.delayer.run(); + if (requestType != FrameType.METADATA_PUSH + && requestType != FrameType.REQUEST_FNF) { + if (st != SignalType.ON_ERROR) { + if (!once2.getAndSet(true)) { + rule.connection.addToReceivedBuffer( + ErrorFrameCodec.encode( + rule.allocator, 1, new IllegalStateException("test"))); + } else { + rule.connection.addToReceivedBuffer( + PayloadFrameCodec.encodeComplete(rule.allocator, 3)); + } + } + } + }); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + Publisher publisher = request.apply(rule.client, source); + if (publisher instanceof Mono) { + ((Mono) publisher) + .retryWhen(Retry.backoff(3, Duration.ofMillis(100))) + .subscribe(assertSubscriber); + } else { + ((Flux) publisher) + .retryWhen(Retry.backoff(3, Duration.ofMillis(100))) + .subscribe(assertSubscriber); + } + + assertSubscriber.request(1); + + if (requestType == FrameType.REQUEST_CHANNEL) { + rule.delayer.run(); + } + + assertSubscriber.await(Duration.ofSeconds(10)).assertComplete(); + + if (requestType == FrameType.REQUEST_CHANNEL) { + ArrayList sent = new ArrayList<>(rule.connection.getSent()); + Assertions.assertThat(sent).hasSize(4); + for (int i = 0; i < sent.size(); i++) { + if (i % 2 == 0) { + Assertions.assertThat(sent.get(i)) + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .matches(ReferenceCounted::release); + } else { + Assertions.assertThat(sent.get(i)) + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(FrameType.COMPLETE)) + .matches(ReferenceCounted::release); + } + } + } else { + Collection sent = rule.connection.getSent(); + Assertions.assertThat(sent) + .hasSize( + requestType == FrameType.REQUEST_FNF || requestType == FrameType.METADATA_PUSH + ? 1 + : 2) + .allMatch(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .allMatch(ReferenceCounted::release); + } + + rule.allocator.assertHasNoLeaks(); + } + + @Test + public void shouldBeAbleToResolveOriginalSource() { + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + rule.client.source().subscribe(assertSubscriber); + + assertSubscriber.assertNotTerminated(); + + rule.delayer.run(); + + assertSubscriber.request(1); + + assertSubscriber.assertTerminated().assertValueCount(1); + + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); + + rule.client.source().subscribe(assertSubscriber1); + + assertSubscriber1.assertTerminated().assertValueCount(1); + + Assertions.assertThat(assertSubscriber1.values()).isEqualTo(assertSubscriber.values()); + + rule.allocator.assertHasNoLeaks(); + } + + @Test + public void shouldDisposeOriginalSource() { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + rule.client.source().subscribe(assertSubscriber); + rule.delayer.run(); + assertSubscriber.assertTerminated().assertValueCount(1); + + rule.client.dispose(); + + Assertions.assertThat(rule.client.isDisposed()).isTrue(); + + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); + + rule.client.source().subscribe(assertSubscriber1); + + assertSubscriber1 + .assertTerminated() + .assertError(CancellationException.class) + .assertErrorMessage("Disposed"); + + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + rule.allocator.assertHasNoLeaks(); + } + + @Test + public void shouldReceiveOnCloseNotificationOnDisposeOriginalSource() { + Sinks.Empty onCloseDelayer = Sinks.empty(); + ClientSocketRule rule = + new ClientSocketRule() { + @Override + protected RSocket newRSocket() { + return new RSocketProxy(super.newRSocket()) { + @Override + public Mono onClose() { + return super.onClose().and(onCloseDelayer.asMono()); + } + }; + } + }; + rule.init(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + rule.client.source().subscribe(assertSubscriber); + rule.delayer.run(); + assertSubscriber.assertTerminated().assertValueCount(1); + + rule.client.dispose(); + + Assertions.assertThat(rule.client.isDisposed()).isTrue(); + + AssertSubscriber onCloseSubscriber = AssertSubscriber.create(); + + rule.client.onClose().subscribe(onCloseSubscriber); + onCloseSubscriber.assertNotTerminated(); + + onCloseDelayer.tryEmitEmpty(); + + onCloseSubscriber.assertTerminated().assertComplete(); + + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + rule.allocator.assertHasNoLeaks(); + } + + @Test + public void shouldResolveOnStartSource() { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + Assertions.assertThat(rule.client.connect()).isTrue(); + rule.client.source().subscribe(assertSubscriber); + rule.delayer.run(); + assertSubscriber.assertTerminated().assertValueCount(1); + + rule.client.dispose(); + + Assertions.assertThat(rule.client.isDisposed()).isTrue(); + + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); + + rule.client.onClose().subscribe(assertSubscriber1); + + assertSubscriber1.assertTerminated().assertComplete(); + + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + rule.allocator.assertHasNoLeaks(); + } + + @Test + public void shouldNotStartIfAlreadyDisposed() { + Assertions.assertThat(rule.client.connect()).isTrue(); + Assertions.assertThat(rule.client.connect()).isTrue(); + rule.delayer.run(); + + rule.client.dispose(); + + Assertions.assertThat(rule.client.connect()).isFalse(); + + Assertions.assertThat(rule.client.isDisposed()).isTrue(); + + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); + + rule.client.onClose().subscribe(assertSubscriber1); + + assertSubscriber1.assertTerminated().assertComplete(); + + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + rule.allocator.assertHasNoLeaks(); + } + + @Test + public void shouldBeRestartedIfSourceWasClosed() { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + AssertSubscriber terminateSubscriber = AssertSubscriber.create(); + + Assertions.assertThat(rule.client.connect()).isTrue(); + rule.client.source().subscribe(assertSubscriber); + rule.client.onClose().subscribe(terminateSubscriber); + + rule.delayer.run(); + + assertSubscriber.assertTerminated().assertValueCount(1); + + rule.socket.dispose(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + terminateSubscriber.assertNotTerminated(); + Assertions.assertThat(rule.client.isDisposed()).isFalse(); + + rule.connection = new TestDuplexConnection(rule.allocator); + rule.socket = rule.newRSocket(); + rule.producer = Sinks.one(); + + AssertSubscriber assertSubscriber2 = AssertSubscriber.create(); + + Assertions.assertThat(rule.client.connect()).isTrue(); + rule.client.source().subscribe(assertSubscriber2); + + rule.delayer.run(); + + assertSubscriber2.assertTerminated().assertValueCount(1); + + rule.client.dispose(); + + terminateSubscriber.assertTerminated().assertComplete(); + + Assertions.assertThat(rule.client.connect()).isFalse(); + + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + rule.allocator.assertHasNoLeaks(); + } + + @Test + public void shouldDisposeOriginalSourceIfRacing() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + ClientSocketRule rule = new ClientSocketRule(); + + rule.init(); + + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + rule.client.source().subscribe(assertSubscriber); + + RaceTestUtils.race(rule.delayer, () -> rule.client.dispose()); + + assertSubscriber.assertTerminated(); + + Assertions.assertThat(rule.client.isDisposed()).isTrue(); + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); + + rule.client.source().subscribe(assertSubscriber1); + + assertSubscriber1 + .assertTerminated() + .assertError(CancellationException.class) + .assertErrorMessage("Disposed"); + + ByteBuf buf; + while ((buf = rule.connection.pollFrame()) != null) { + FrameAssert.assertThat(buf).hasStreamIdZero().hasData("Disposed").hasNoLeaks(); + } + + rule.allocator.assertHasNoLeaks(); + } + } + + @Test + public void shouldStartOriginalSourceOnceIfRacing() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + ClientSocketRule rule = new ClientSocketRule(); + + rule.init(); + + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + RaceTestUtils.race( + () -> rule.client.source().subscribe(assertSubscriber), () -> rule.client.connect()); + + Assertions.assertThat(rule.producer.currentSubscriberCount()).isOne(); + + rule.delayer.run(); + + assertSubscriber.assertTerminated(); + + rule.client.dispose(); + + Assertions.assertThat(rule.client.isDisposed()).isTrue(); + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); + + rule.client.onClose().subscribe(assertSubscriber1); + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + assertSubscriber1.assertTerminated().assertComplete(); + + rule.allocator.assertHasNoLeaks(); + } + } + + public static class ClientSocketRule extends AbstractSocketRule { + + protected RSocketClient client; + protected Runnable delayer; + protected Sinks.One producer; + + protected Sinks.Empty thisClosedSink; + + @Override + protected void doInit() { + super.doInit(); + delayer = () -> producer.tryEmitValue(socket); + producer = Sinks.one(); + client = + new DefaultRSocketClient( + Mono.defer( + () -> + producer + .asMono() + .doOnCancel(() -> socket.dispose()) + .doOnDiscard(Disposable.class, Disposable::dispose))); + } + + @Override + protected RSocket newRSocket() { + this.thisClosedSink = Sinks.empty(); + return new RSocketRequester( + connection, + PayloadDecoder.ZERO_COPY, + StreamIdSupplier.clientSupplier(), + 0, + maxFrameLength, + maxInboundPayloadSize, + Integer.MAX_VALUE, + Integer.MAX_VALUE, + null, + __ -> null, + null, + thisClosedSink, + thisClosedSink.asMono()); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java b/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java new file mode 100644 index 000000000..f5422a4bf --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java @@ -0,0 +1,448 @@ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA; +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.TestRequesterResponderSupport.genericPayload; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.TestRequestInterceptor; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import java.util.Arrays; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.test.StepVerifier; +import reactor.test.util.RaceTestUtils; + +public class FireAndForgetRequesterMonoTest { + + @BeforeAll + public static void setUp() { + StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); + } + + /** + * General StateMachine transition test. No Fragmentation enabled In this test we check that the + * given instance of FireAndForgetMono subscribes, and then sends frame immediately + */ + @ParameterizedTest + @MethodSource("frameSent") + public void frameShouldBeSentOnSubscription(Consumer monoConsumer) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Payload payload = genericPayload(activeStreams.getAllocator()); + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + monoConsumer.accept(fireAndForgetRequesterMono); + + Assertions.assertThat(payload.refCnt()).isZero(); + // should not add anything to map + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + final ByteBuf frame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .typeOf(FrameType.REQUEST_FNF) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectNothing(); + } + + /** + * General StateMachine transition test. Fragmentation enabled In this test we check that the + * given instance of FireAndForgetMono subscribes, and then sends all fragments as a separate + * frame immediately + */ + @ParameterizedTest + @MethodSource("frameSent") + public void frameFragmentsShouldBeSentOnSubscription( + Consumer monoConsumer) { + final int mtu = 64; + final TestRequesterResponderSupport streamManager = TestRequesterResponderSupport.client(mtu); + final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); + final TestDuplexConnection sender = streamManager.getDuplexConnection(); + + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, streamManager); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + stateAssert.isUnsubscribed(); + streamManager.assertNoActiveStreams(); + + monoConsumer.accept(fireAndForgetRequesterMono); + + // should not add anything to map + streamManager.assertNoActiveStreams(); + stateAssert.isTerminated(); + + Assertions.assertThat(payload.refCnt()).isZero(); + + final ByteBuf frameFragment1 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment1) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET_WITH_METADATA) // 64 - 6 (frame headers) - 3 (encoded metadata + // length) - 3 frame length + .hasMetadata(Arrays.copyOf(metadata, 52)) + .hasData(Unpooled.EMPTY_BUFFER) + .hasFragmentsFollow() + .typeOf(FrameType.REQUEST_FNF) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment2 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment2) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET_WITH_METADATA) // 64 - 6 (frame headers) - 3 (encoded metadata + // length) - 3 frame length + .hasMetadata(Arrays.copyOfRange(metadata, 52, 65)) + .hasData(Arrays.copyOf(data, 39)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment3 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment3) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET) // 64 - 6 (frame headers) - 3 frame length (no metadata - no length) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 39, 94)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment4 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment4) + .isNotNull() + .hasPayloadSize(35) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 94, 129)) + .hasNoFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + static Stream> frameSent() { + return Stream.of( + (s) -> StepVerifier.create(s).expectSubscription().expectComplete().verify(), + FireAndForgetRequesterMono::block); + } + + /** + * RefCnt validation test. Should send error if RefCnt is incorrect and frame has already been + * released Note: ONCE state should be 0 + */ + @ParameterizedTest + @MethodSource("shouldErrorOnIncorrectRefCntInGivenPayloadSource") + public void shouldErrorOnIncorrectRefCntInGivenPayload( + Consumer monoConsumer) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport streamManager = + TestRequesterResponderSupport.client(testRequestInterceptor); + final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); + final TestDuplexConnection sender = streamManager.getDuplexConnection(); + final Payload payload = ByteBufPayload.create(""); + payload.release(); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, streamManager); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + stateAssert.isUnsubscribed(); + streamManager.assertNoActiveStreams(); + + monoConsumer.accept(fireAndForgetRequesterMono); + + stateAssert.isTerminated(); + streamManager.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + testRequestInterceptor + .expectOnReject(FrameType.REQUEST_FNF, new IllegalReferenceCountException("refCnt: 0")) + .expectNothing(); + } + + static Stream> + shouldErrorOnIncorrectRefCntInGivenPayloadSource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .expectError(IllegalReferenceCountException.class) + .verify(), + fireAndForgetRequesterMono -> + Assertions.assertThatThrownBy(fireAndForgetRequesterMono::block) + .isInstanceOf(IllegalReferenceCountException.class)); + } + + /** + * Check that proper payload size validation is enabled so in case payload fragmentation is + * disabled we will not send anything bigger that 16MB (see specification for MAX frame size) + */ + @ParameterizedTest + @MethodSource("shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource") + public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( + Consumer monoConsumer) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport streamManager = + TestRequesterResponderSupport.client(testRequestInterceptor); + final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); + final TestDuplexConnection sender = streamManager.getDuplexConnection(); + + final byte[] metadata = new byte[FRAME_LENGTH_MASK]; + final byte[] data = new byte[FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, streamManager); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + stateAssert.isUnsubscribed(); + streamManager.assertNoActiveStreams(); + + monoConsumer.accept(fireAndForgetRequesterMono); + + Assertions.assertThat(payload.refCnt()).isZero(); + + stateAssert.isTerminated(); + streamManager.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + testRequestInterceptor + .expectOnReject( + FrameType.REQUEST_FNF, + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK))) + .expectNothing(); + } + + static Stream> + shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)) + .verify(), + fireAndForgetRequesterMono -> + Assertions.assertThatThrownBy(fireAndForgetRequesterMono::block) + .hasMessage(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)); + } + + /** + * Ensures that frame will not be sent if we dont have availability for that. Options: 1. RSocket + * disposed / Connection Error, so all racing on existing interactions should be terminated as + * well 2. RSocket tries to use lease and end-ups with no available leases + */ + @ParameterizedTest + @MethodSource("shouldErrorIfNoAvailabilitySource") + public void shouldErrorIfNoAvailability(Consumer monoConsumer) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final RuntimeException exception = new RuntimeException("test"); + final TestRequesterResponderSupport streamManager = + TestRequesterResponderSupport.client(exception, testRequestInterceptor); + final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); + final TestDuplexConnection sender = streamManager.getDuplexConnection(); + final Payload payload = genericPayload(allocator); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, streamManager); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + stateAssert.isUnsubscribed(); + streamManager.assertNoActiveStreams(); + + monoConsumer.accept(fireAndForgetRequesterMono); + + Assertions.assertThat(payload.refCnt()).isZero(); + + stateAssert.isTerminated(); + streamManager.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + testRequestInterceptor.expectOnReject(FrameType.REQUEST_FNF, exception).expectNothing(); + } + + static Stream> shouldErrorIfNoAvailabilitySource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)) + .verify(), + fireAndForgetRequesterMono -> + Assertions.assertThatThrownBy(fireAndForgetRequesterMono::block) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)); + } + + /** Ensures single subscription happens in case of racing */ + @Test + public void shouldSubscribeExactlyOnce1() { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport streamManager = + TestRequesterResponderSupport.client(testRequestInterceptor); + final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); + final TestDuplexConnection sender = streamManager.getDuplexConnection(); + + for (int i = 1; i < 50000; i += 2) { + final Payload payload = ByteBufPayload.create("testData", "testMetadata"); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, streamManager); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + Assertions.assertThatThrownBy( + () -> + RaceTestUtils.race( + () -> { + AtomicReference atomicReference = new AtomicReference<>(); + fireAndForgetRequesterMono.subscribe(null, atomicReference::set); + Throwable throwable = atomicReference.get(); + if (throwable != null) { + throw Exceptions.propagate(throwable); + } + }, + fireAndForgetRequesterMono::block)) + .matches( + t -> { + Assertions.assertThat(t) + .hasMessageContaining("FireAndForgetMono allows only a single Subscriber"); + return true; + }); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .typeOf(FrameType.REQUEST_FNF) + .hasClientSideStreamId() + .hasStreamId(i) + .hasNoLeaks(); + + stateAssert.isTerminated(); + streamManager.assertNoActiveStreams(); + testRequestInterceptor + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .expectNothing(); + } + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + @Test + public void checkName() { + final TestRequesterResponderSupport testRequesterResponderSupport = + TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); + final Payload payload = ByteBufPayload.create("testData", "testMetadata"); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, testRequesterResponderSupport); + + Assertions.assertThat(Scannable.from(fireAndForgetRequesterMono).name()) + .isEqualTo("source(FireAndForgetMono)"); + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java new file mode 100644 index 000000000..5be59235c --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java @@ -0,0 +1,420 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.keepalive.KeepAliveHandler.DefaultKeepAliveHandler; +import static io.rsocket.keepalive.KeepAliveHandler.ResumableKeepAliveHandler; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.FrameAssert; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.resume.InMemoryResumableFramesStore; +import io.rsocket.resume.RSocketSession; +import io.rsocket.resume.ResumableDuplexConnection; +import io.rsocket.resume.ResumeStateHolder; +import io.rsocket.test.util.TestDuplexConnection; +import java.time.Duration; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; +import reactor.test.scheduler.VirtualTimeScheduler; + +public class KeepAliveTest { + private static final int KEEP_ALIVE_INTERVAL = 100; + private static final int KEEP_ALIVE_TIMEOUT = 1000; + private static final int RESUMABLE_KEEP_ALIVE_TIMEOUT = 200; + + VirtualTimeScheduler virtualTimeScheduler; + + @BeforeEach + public void setUp() { + virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + } + + @AfterEach + public void tearDown() { + VirtualTimeScheduler.reset(); + } + + static RSocketState requester(int tickPeriod, int timeout) { + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + TestDuplexConnection connection = new TestDuplexConnection(allocator); + Sinks.Empty empty = Sinks.empty(); + RSocketRequester rSocket = + new RSocketRequester( + connection, + PayloadDecoder.ZERO_COPY, + StreamIdSupplier.clientSupplier(), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + tickPeriod, + timeout, + new DefaultKeepAliveHandler(), + r -> null, + null, + empty, + empty.asMono()); + return new RSocketState(rSocket, allocator, connection, empty); + } + + static ResumableRSocketState resumableRequester(int tickPeriod, int timeout) { + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + TestDuplexConnection connection = new TestDuplexConnection(allocator); + ResumableDuplexConnection resumableConnection = + new ResumableDuplexConnection( + "test", + Unpooled.EMPTY_BUFFER, + connection, + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 10_000)); + Sinks.Empty onClose = Sinks.empty(); + + RSocketRequester rSocket = + new RSocketRequester( + resumableConnection, + PayloadDecoder.ZERO_COPY, + StreamIdSupplier.clientSupplier(), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + tickPeriod, + timeout, + new ResumableKeepAliveHandler( + resumableConnection, + Mockito.mock(RSocketSession.class), + Mockito.mock(ResumeStateHolder.class)), + __ -> null, + null, + onClose, + onClose.asMono()); + return new ResumableRSocketState(rSocket, connection, resumableConnection, onClose, allocator); + } + + @Test + void rSocketNotDisposedOnPresentKeepAlives() { + RSocketState requesterState = requester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); + + TestDuplexConnection connection = requesterState.connection(); + + Disposable disposable = + Flux.interval(Duration.ofMillis(KEEP_ALIVE_INTERVAL)) + .subscribe( + n -> + connection.addToReceivedBuffer( + KeepAliveFrameCodec.encode( + requesterState.allocator, true, 0, Unpooled.EMPTY_BUFFER))); + + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_TIMEOUT * 2)); + + RSocket rSocket = requesterState.rSocket(); + + Assertions.assertThat(rSocket.isDisposed()).isFalse(); + + disposable.dispose(); + + requesterState.connection.dispose(); + requesterState.rSocket.dispose(); + + Assertions.assertThat(requesterState.connection.getSent()).allMatch(ByteBuf::release); + + requesterState.allocator.assertHasNoLeaks(); + } + + @Test + void noKeepAlivesSentAfterRSocketDispose() { + RSocketState requesterState = requester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); + + requesterState.rSocket().dispose(); + + Duration duration = Duration.ofMillis(500); + + virtualTimeScheduler.advanceTimeBy(duration); + + FrameAssert.assertThat(requesterState.connection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasData("Disposed") + .hasNoLeaks(); + FrameAssert.assertThat(requesterState.connection.pollFrame()).isNull(); + requesterState.allocator.assertHasNoLeaks(); + } + + @Test + void rSocketDisposedOnMissingKeepAlives() { + RSocketState requesterState = requester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); + + RSocket rSocket = requesterState.rSocket(); + + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_TIMEOUT * 2)); + + Assertions.assertThat(rSocket.isDisposed()).isTrue(); + rSocket + .onClose() + .as(StepVerifier::create) + .expectError(ConnectionErrorException.class) + .verify(Duration.ofMillis(100)); + + Assertions.assertThat(requesterState.connection.getSent()).allMatch(ByteBuf::release); + + requesterState.allocator.assertHasNoLeaks(); + } + + @Test + void clientRequesterSendsKeepAlives() { + RSocketState RSocketState = requester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); + TestDuplexConnection connection = RSocketState.connection(); + + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL)); + this.keepAliveFrameWithRespondFlag(connection.pollFrame()); + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL)); + this.keepAliveFrameWithRespondFlag(connection.pollFrame()); + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL)); + this.keepAliveFrameWithRespondFlag(connection.pollFrame()); + + RSocketState.rSocket.dispose(); + FrameAssert.assertThat(connection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasData("Disposed") + .hasNoLeaks(); + RSocketState.connection.dispose(); + + RSocketState.allocator.assertHasNoLeaks(); + } + + @Test + void requesterRespondsToKeepAlives() { + RSocketState rSocketState = requester(100_000, 100_000); + TestDuplexConnection connection = rSocketState.connection(); + Duration duration = Duration.ofMillis(100); + Mono.delay(duration) + .subscribe( + l -> + connection.addToReceivedBuffer( + KeepAliveFrameCodec.encode( + rSocketState.allocator, true, 0, Unpooled.EMPTY_BUFFER))); + + virtualTimeScheduler.advanceTimeBy(duration); + FrameAssert.assertThat(connection.awaitFrame()) + .typeOf(FrameType.KEEPALIVE) + .matches(this::keepAliveFrameWithoutRespondFlag); + + rSocketState.rSocket.dispose(); + FrameAssert.assertThat(rSocketState.connection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + rSocketState.connection.dispose(); + + rSocketState.allocator.assertHasNoLeaks(); + } + + @Test + void resumableRequesterNoKeepAlivesAfterDisconnect() { + ResumableRSocketState rSocketState = + resumableRequester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); + TestDuplexConnection testConnection = rSocketState.connection(); + ResumableDuplexConnection resumableDuplexConnection = rSocketState.resumableDuplexConnection(); + + resumableDuplexConnection.disconnect(); + + Duration duration = Duration.ofMillis(KEEP_ALIVE_INTERVAL * 5); + virtualTimeScheduler.advanceTimeBy(duration); + Assertions.assertThat(testConnection.pollFrame()).isNull(); + + rSocketState.rSocket.dispose(); + rSocketState.connection.dispose(); + + rSocketState.allocator.assertHasNoLeaks(); + } + + @Test + void resumableRequesterKeepAlivesAfterReconnect() { + ResumableRSocketState rSocketState = + resumableRequester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); + ResumableDuplexConnection resumableDuplexConnection = rSocketState.resumableDuplexConnection(); + resumableDuplexConnection.disconnect(); + TestDuplexConnection newTestConnection = new TestDuplexConnection(rSocketState.alloc()); + resumableDuplexConnection.connect(newTestConnection); + // resumableDuplexConnection.(0, 0, ignored -> Mono.empty()); + + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL)); + + FrameAssert.assertThat(newTestConnection.awaitFrame()) + .typeOf(FrameType.KEEPALIVE) + .hasStreamIdZero() + .hasNoLeaks(); + + rSocketState.rSocket.dispose(); + FrameAssert.assertThat(newTestConnection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + FrameAssert.assertThat(newTestConnection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasStreamIdZero() + .hasData("Connection Closed Unexpectedly") // API limitations + .hasNoLeaks(); + newTestConnection.dispose(); + + rSocketState.allocator.assertHasNoLeaks(); + } + + @Test + void resumableRequesterNoKeepAlivesAfterDispose() { + ResumableRSocketState rSocketState = + resumableRequester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); + rSocketState.rSocket().dispose(); + Duration duration = Duration.ofMillis(500); + StepVerifier.create(Flux.from(rSocketState.connection().getSentAsPublisher()).take(duration)) + .then(() -> virtualTimeScheduler.advanceTimeBy(duration)) + .expectComplete() + .verify(Duration.ofSeconds(5)); + + rSocketState.rSocket.dispose(); + FrameAssert.assertThat(rSocketState.connection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + rSocketState.connection.dispose(); + FrameAssert.assertThat(rSocketState.connection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasStreamIdZero() + .hasData("Connection Closed Unexpectedly") + .hasNoLeaks(); + + rSocketState.allocator.assertHasNoLeaks(); + } + + @Test + void resumableRSocketsNotDisposedOnMissingKeepAlives() throws InterruptedException { + ResumableRSocketState resumableRequesterState = + resumableRequester(KEEP_ALIVE_INTERVAL, RESUMABLE_KEEP_ALIVE_TIMEOUT); + RSocket rSocket = resumableRequesterState.rSocket(); + TestDuplexConnection connection = resumableRequesterState.connection(); + + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(500)); + + Assertions.assertThat(rSocket.isDisposed()).isFalse(); + Assertions.assertThat(connection.isDisposed()).isTrue(); + + Assertions.assertThat(resumableRequesterState.connection.getSent()).allMatch(ByteBuf::release); + + resumableRequesterState.connection.dispose(); + resumableRequesterState.rSocket.dispose(); + + resumableRequesterState.allocator.assertHasNoLeaks(); + } + + private boolean keepAliveFrame(ByteBuf frame) { + return FrameHeaderCodec.frameType(frame) == FrameType.KEEPALIVE; + } + + private boolean keepAliveFrameWithRespondFlag(ByteBuf frame) { + return keepAliveFrame(frame) && KeepAliveFrameCodec.respondFlag(frame) && frame.release(); + } + + private boolean keepAliveFrameWithoutRespondFlag(ByteBuf frame) { + return keepAliveFrame(frame) && !KeepAliveFrameCodec.respondFlag(frame) && frame.release(); + } + + static class RSocketState { + private final RSocket rSocket; + private final TestDuplexConnection connection; + private final LeaksTrackingByteBufAllocator allocator; + private final Sinks.Empty onClose; + + public RSocketState( + RSocket rSocket, + LeaksTrackingByteBufAllocator allocator, + TestDuplexConnection connection, + Sinks.Empty onClose) { + this.rSocket = rSocket; + this.connection = connection; + this.allocator = allocator; + this.onClose = onClose; + } + + public TestDuplexConnection connection() { + return connection; + } + + public RSocket rSocket() { + return rSocket; + } + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } + } + + static class ResumableRSocketState { + private final RSocket rSocket; + private final TestDuplexConnection connection; + private final ResumableDuplexConnection resumableDuplexConnection; + private final LeaksTrackingByteBufAllocator allocator; + private final Sinks.Empty onClose; + + public ResumableRSocketState( + RSocket rSocket, + TestDuplexConnection connection, + ResumableDuplexConnection resumableDuplexConnection, + Sinks.Empty onClose, + LeaksTrackingByteBufAllocator allocator) { + this.rSocket = rSocket; + this.connection = connection; + this.resumableDuplexConnection = resumableDuplexConnection; + this.onClose = onClose; + this.allocator = allocator; + } + + public TestDuplexConnection connection() { + return connection; + } + + public ResumableDuplexConnection resumableDuplexConnection() { + return resumableDuplexConnection; + } + + public RSocket rSocket() { + return rSocket; + } + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java b/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java new file mode 100644 index 000000000..707d42afe --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java @@ -0,0 +1,142 @@ +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_SIZE; + +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.util.DefaultPayload; +import java.util.concurrent.ThreadLocalRandom; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +class PayloadValidationUtilsTest { + + @Test + void shouldBeValidFrameWithNoFragmentation() { + int maxFrameLength = + ThreadLocalRandom.current().nextInt(64, FrameLengthCodec.FRAME_LENGTH_MASK); + byte[] data = new byte[maxFrameLength - FRAME_LENGTH_SIZE - FrameHeaderCodec.size()]; + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isFalse(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation1() { + int maxFrameLength = + ThreadLocalRandom.current().nextInt(64, FrameLengthCodec.FRAME_LENGTH_MASK); + byte[] data = + new byte[maxFrameLength - FRAME_LENGTH_SIZE - Integer.BYTES - FrameHeaderCodec.size()]; + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isTrue(); + } + + @Test + void shouldBeInValidFrameWithNoFragmentation() { + int maxFrameLength = + ThreadLocalRandom.current().nextInt(64, FrameLengthCodec.FRAME_LENGTH_MASK); + byte[] data = new byte[maxFrameLength - FRAME_LENGTH_SIZE - FrameHeaderCodec.size() + 1]; + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isFalse(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isFalse(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation0() { + int maxFrameLength = + ThreadLocalRandom.current().nextInt(64, FrameLengthCodec.FRAME_LENGTH_MASK); + byte[] metadata = new byte[maxFrameLength / 2]; + byte[] data = + new byte + [(maxFrameLength / 2 + 1) + - FRAME_LENGTH_SIZE + - FrameHeaderCodec.size() + - FrameHeaderCodec.size()]; + ThreadLocalRandom.current().nextBytes(data); + ThreadLocalRandom.current().nextBytes(metadata); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isFalse(); + } + + @Test + void shouldBeInValidFrameWithNoFragmentation1() { + int maxFrameLength = + ThreadLocalRandom.current().nextInt(64, FrameLengthCodec.FRAME_LENGTH_MASK); + byte[] metadata = new byte[maxFrameLength]; + byte[] data = new byte[maxFrameLength]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isFalse(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isFalse(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation2() { + int maxFrameLength = + ThreadLocalRandom.current().nextInt(64, FrameLengthCodec.FRAME_LENGTH_MASK); + byte[] metadata = new byte[1]; + byte[] data = new byte[1]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isTrue(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation3() { + int maxFrameLength = + ThreadLocalRandom.current().nextInt(64, FrameLengthCodec.FRAME_LENGTH_MASK); + byte[] metadata = new byte[maxFrameLength]; + byte[] data = new byte[maxFrameLength]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(64, maxFrameLength, payload, true)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(64, maxFrameLength, payload, false)) + .isTrue(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation4() { + int maxFrameLength = + ThreadLocalRandom.current().nextInt(64, FrameLengthCodec.FRAME_LENGTH_MASK); + byte[] metadata = new byte[1]; + byte[] data = new byte[1]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(64, maxFrameLength, payload, true)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(64, maxFrameLength, payload, false)) + .isTrue(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketConnectorTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketConnectorTest.java new file mode 100644 index 000000000..7cf12a81e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketConnectorTest.java @@ -0,0 +1,308 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCounted; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.util.retry.Retry; + +public class RSocketConnectorTest { + + @ParameterizedTest + @ValueSource(strings = {"KEEPALIVE", "REQUEST_RESPONSE"}) + public void unexpectedFramesBeforeResumeOKFrame(String frameType) { + TestClientTransport transport = new TestClientTransport(); + RSocketConnector.create() + .resume(new Resume().retry(Retry.indefinitely())) + .connect(transport) + .block(); + + final TestDuplexConnection duplexConnection = transport.testConnection(); + + duplexConnection.addToReceivedBuffer( + KeepAliveFrameCodec.encode(duplexConnection.alloc(), false, 1, Unpooled.EMPTY_BUFFER)); + FrameAssert.assertThat(duplexConnection.pollFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); + + FrameAssert.assertThat(duplexConnection.pollFrame()).isNull(); + + duplexConnection.dispose(); + + final TestDuplexConnection duplexConnection2 = transport.testConnection(); + + final ByteBuf frame; + switch (frameType) { + case "KEEPALIVE": + frame = + KeepAliveFrameCodec.encode(duplexConnection2.alloc(), false, 1, Unpooled.EMPTY_BUFFER); + break; + case "REQUEST_RESPONSE": + default: + frame = + RequestResponseFrameCodec.encode( + duplexConnection2.alloc(), 2, false, Unpooled.EMPTY_BUFFER, Unpooled.EMPTY_BUFFER); + } + duplexConnection2.addToReceivedBuffer(frame); + + StepVerifier.create(duplexConnection2.onClose()) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(10)); + + FrameAssert.assertThat(duplexConnection2.pollFrame()) + .typeOf(FrameType.RESUME) + .hasStreamIdZero() + .hasNoLeaks(); + + FrameAssert.assertThat(duplexConnection2.pollFrame()) + .isNotNull() + .typeOf(FrameType.ERROR) + .hasData("RESUME_OK frame must be received before any others") + .hasStreamIdZero() + .hasNoLeaks(); + + transport.alloc().assertHasNoLeaks(); + } + + @Test + public void ensuresThatSetupPayloadCanBeRetained() { + AtomicReference retainedSetupPayload = new AtomicReference<>(); + TestClientTransport transport = new TestClientTransport(); + + ByteBuf data = transport.alloc().buffer(); + + data.writeCharSequence("data", CharsetUtil.UTF_8); + + RSocketConnector.create() + .setupPayload(ByteBufPayload.create(data)) + .acceptor( + (setup, sendingSocket) -> { + retainedSetupPayload.set(setup.retain()); + return Mono.just(new RSocket() {}); + }) + .connect(transport) + .block(); + + assertThat(transport.testConnection().getSent()) + .hasSize(1) + .first() + .matches( + bb -> { + DefaultConnectionSetupPayload payload = new DefaultConnectionSetupPayload(bb); + return !payload.hasMetadata() && payload.getDataUtf8().equals("data"); + }) + .matches(buf -> buf.refCnt() == 2) + .matches( + buf -> { + buf.release(); + return buf.refCnt() == 1; + }); + + ConnectionSetupPayload setup = retainedSetupPayload.get(); + String dataUtf8 = setup.getDataUtf8(); + assertThat("data".equals(dataUtf8) && setup.release()).isTrue(); + assertThat(setup.refCnt()).isZero(); + + transport.alloc().assertHasNoLeaks(); + } + + @Test + public void ensuresThatMonoFromRSocketConnectorCanBeUsedForMultipleSubscriptions() { + Payload setupPayload = ByteBufPayload.create("TestData", "TestMetadata"); + assertThat(setupPayload.refCnt()).isOne(); + + // Keep the data and metadata around so we can try changing them independently + ByteBuf dataBuf = setupPayload.data(); + ByteBuf metadataBuf = setupPayload.metadata(); + dataBuf.retain(); + metadataBuf.retain(); + + TestClientTransport testClientTransport = new TestClientTransport(); + Mono connectionMono = + RSocketConnector.create().setupPayload(setupPayload).connect(testClientTransport); + + connectionMono + .as(StepVerifier::create) + .expectNextCount(1) + .expectComplete() + .verify(Duration.ofMillis(100)); + + assertThat(testClientTransport.testConnection().getSent()) + .hasSize(1) + .allMatch( + bb -> { + DefaultConnectionSetupPayload payload = new DefaultConnectionSetupPayload(bb); + return payload.getDataUtf8().equals("TestData") + && payload.getMetadataUtf8().equals("TestMetadata"); + }) + .allMatch(ReferenceCounted::release); + + connectionMono + .as(StepVerifier::create) + .expectNextCount(1) + .expectComplete() + .verify(Duration.ofMillis(100)); + + // Changing the original data and metadata should not impact the SetupPayload + dataBuf.writerIndex(dataBuf.readerIndex()); + dataBuf.writeChar('d'); + dataBuf.release(); + + metadataBuf.writerIndex(metadataBuf.readerIndex()); + metadataBuf.writeChar('m'); + metadataBuf.release(); + + assertThat(testClientTransport.testConnection().getSent()) + .hasSize(1) + .allMatch( + bb -> { + DefaultConnectionSetupPayload payload = new DefaultConnectionSetupPayload(bb); + return payload.getDataUtf8().equals("TestData") + && payload.getMetadataUtf8().equals("TestMetadata"); + }) + .allMatch( + byteBuf -> { + System.out.println("calling release " + byteBuf.refCnt()); + return byteBuf.release(); + }); + assertThat(setupPayload.refCnt()).isZero(); + + testClientTransport.alloc().assertHasNoLeaks(); + } + + @Test + public void ensuresThatSetupPayloadProvidedAsMonoIsReleased() { + List saved = new ArrayList<>(); + AtomicLong subscriptions = new AtomicLong(); + Mono setupPayloadMono = + Mono.create( + sink -> { + final long subscriptionN = subscriptions.getAndIncrement(); + Payload payload = + ByteBufPayload.create("TestData" + subscriptionN, "TestMetadata" + subscriptionN); + saved.add(payload); + sink.success(payload); + }); + + TestClientTransport testClientTransport = new TestClientTransport(); + Mono connectionMono = + RSocketConnector.create().setupPayload(setupPayloadMono).connect(testClientTransport); + + connectionMono + .as(StepVerifier::create) + .expectNextCount(1) + .expectComplete() + .verify(Duration.ofMillis(100)); + + assertThat(testClientTransport.testConnection().getSent()) + .hasSize(1) + .allMatch( + bb -> { + DefaultConnectionSetupPayload payload = new DefaultConnectionSetupPayload(bb); + return payload.getDataUtf8().equals("TestData0") + && payload.getMetadataUtf8().equals("TestMetadata0"); + }) + .allMatch(ReferenceCounted::release); + + connectionMono + .as(StepVerifier::create) + .expectNextCount(1) + .expectComplete() + .verify(Duration.ofMillis(100)); + + assertThat(testClientTransport.testConnection().getSent()) + .hasSize(1) + .allMatch( + bb -> { + DefaultConnectionSetupPayload payload = new DefaultConnectionSetupPayload(bb); + return payload.getDataUtf8().equals("TestData1") + && payload.getMetadataUtf8().equals("TestMetadata1"); + }) + .allMatch(ReferenceCounted::release); + + assertThat(saved) + .as("Metadata and data were consumed and released as slices") + .allMatch( + payload -> + payload.refCnt() == 1 + && payload.data().refCnt() == 0 + && payload.metadata().refCnt() == 0); + + testClientTransport.alloc().assertHasNoLeaks(); + } + + @Test + public void ensuresMaxFrameLengthCanNotBeLessThenMtu() { + RSocketConnector.create() + .fragment(128) + .connect(new TestClientTransport().withMaxFrameLength(64)) + .as(StepVerifier::create) + .expectErrorMessage( + "Configured maximumTransmissionUnit[128] exceeds configured maxFrameLength[64]") + .verify(); + } + + @Test + public void ensuresMaxFrameLengthCanNotBeGreaterThenMaxPayloadSize() { + RSocketConnector.create() + .maxInboundPayloadSize(128) + .connect(new TestClientTransport().withMaxFrameLength(256)) + .as(StepVerifier::create) + .expectErrorMessage("Configured maxFrameLength[256] exceeds maxPayloadSize[128]") + .verify(); + } + + @Test + public void ensuresMaxFrameLengthCanNotBeGreaterThenMaxPossibleFrameLength() { + RSocketConnector.create() + .connect(new TestClientTransport().withMaxFrameLength(Integer.MAX_VALUE)) + .as(StepVerifier::create) + .expectErrorMessage( + "Configured maxFrameLength[" + + Integer.MAX_VALUE + + "] " + + "exceeds maxFrameLength limit " + + FRAME_LENGTH_MASK) + .verify(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java new file mode 100644 index 000000000..a461833d3 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java @@ -0,0 +1,724 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameType.COMPLETE; +import static io.rsocket.frame.FrameType.ERROR; +import static io.rsocket.frame.FrameType.LEASE; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_FNF; +import static io.rsocket.frame.FrameType.SETUP; +import static org.assertj.core.data.Offset.offset; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCounted; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.Exceptions; +import io.rsocket.exceptions.RejectedException; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.lease.Lease; +import io.rsocket.lease.MissingLeaseException; +import io.rsocket.plugins.InitializingInterceptorRegistry; +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.test.util.TestServerTransport; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.Collection; +import java.util.function.BiFunction; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; + +class RSocketLeaseTest { + private static final String TAG = "test"; + + private RSocket rSocketRequester; + private ResponderLeaseTracker responderLeaseTracker; + private LeaksTrackingByteBufAllocator byteBufAllocator; + private TestDuplexConnection connection; + private RSocketResponder rSocketResponder; + private RSocket mockRSocketHandler; + + private Sinks.Many leaseSender = Sinks.many().multicast().onBackpressureBuffer(); + private RequesterLeaseTracker requesterLeaseTracker; + protected Sinks.Empty thisClosedSink; + protected Sinks.Empty otherClosedSink; + + @BeforeEach + void setUp() { + PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; + byteBufAllocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + connection = new TestDuplexConnection(byteBufAllocator); + requesterLeaseTracker = new RequesterLeaseTracker(TAG, 0); + responderLeaseTracker = new ResponderLeaseTracker(TAG, connection, () -> leaseSender.asFlux()); + this.thisClosedSink = Sinks.empty(); + this.otherClosedSink = Sinks.empty(); + + ClientServerInputMultiplexer multiplexer = + new ClientServerInputMultiplexer(connection, new InitializingInterceptorRegistry(), true); + rSocketRequester = + new RSocketRequester( + multiplexer.asClientConnection(), + payloadDecoder, + StreamIdSupplier.clientSupplier(), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + 0, + 0, + null, + __ -> null, + requesterLeaseTracker, + thisClosedSink, + otherClosedSink.asMono().and(thisClosedSink.asMono())); + + mockRSocketHandler = mock(RSocket.class); + when(mockRSocketHandler.metadataPush(any())) + .then( + a -> { + Payload payload = a.getArgument(0); + payload.release(); + return Mono.empty(); + }); + when(mockRSocketHandler.fireAndForget(any())) + .then( + a -> { + Payload payload = a.getArgument(0); + payload.release(); + return Mono.empty(); + }); + when(mockRSocketHandler.requestResponse(any())) + .then( + a -> { + Payload payload = a.getArgument(0); + payload.release(); + return Mono.empty(); + }); + when(mockRSocketHandler.requestStream(any())) + .then( + a -> { + Payload payload = a.getArgument(0); + payload.release(); + return Flux.empty(); + }); + when(mockRSocketHandler.requestChannel(any())) + .then( + a -> { + Publisher payloadPublisher = a.getArgument(0); + return Flux.from(payloadPublisher) + .doOnNext(ReferenceCounted::release) + .transform( + Operators.lift( + (__, actual) -> + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) { + actual.onSubscribe(this); + } + + @Override + protected void hookOnComplete() { + actual.onComplete(); + } + + @Override + protected void hookOnError(Throwable throwable) { + actual.onError(throwable); + } + })); + }); + + rSocketResponder = + new RSocketResponder( + multiplexer.asServerConnection(), + mockRSocketHandler, + payloadDecoder, + responderLeaseTracker, + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + __ -> null, + otherClosedSink); + } + + @AfterEach + void tearDownAndCheckForLeaks() { + byteBufAllocator.assertHasNoLeaks(); + } + + @Test + public void serverRSocketFactoryRejectsUnsupportedLease() { + Payload payload = DefaultPayload.create(DefaultPayload.EMPTY_BUFFER); + ByteBuf setupFrame = + SetupFrameCodec.encode( + ByteBufAllocator.DEFAULT, + true, + 1000, + 30_000, + "application/octet-stream", + "application/octet-stream", + payload); + + TestServerTransport transport = new TestServerTransport(); + RSocketServer.create().bind(transport).block(); + + TestDuplexConnection connection = transport.connect(); + connection.addToReceivedBuffer(setupFrame); + + Collection sent = connection.getSent(); + Assertions.assertThat(sent).hasSize(1); + ByteBuf error = sent.iterator().next(); + Assertions.assertThat(FrameHeaderCodec.frameType(error)).isEqualTo(ERROR); + Assertions.assertThat(Exceptions.from(0, error).getMessage()) + .isEqualTo("lease is not supported"); + error.release(); + connection.dispose(); + transport.alloc().assertHasNoLeaks(); + } + + @Test + public void clientRSocketFactorySetsLeaseFlag() { + TestClientTransport clientTransport = new TestClientTransport(); + try { + RSocketConnector.create().lease().connect(clientTransport).block(); + Collection sent = clientTransport.testConnection().getSent(); + Assertions.assertThat(sent).hasSize(1); + ByteBuf setup = sent.iterator().next(); + Assertions.assertThat(FrameHeaderCodec.frameType(setup)).isEqualTo(SETUP); + Assertions.assertThat(SetupFrameCodec.honorLease(setup)).isTrue(); + setup.release(); + } finally { + clientTransport.testConnection().dispose(); + clientTransport.alloc().assertHasNoLeaks(); + } + } + + @ParameterizedTest + @MethodSource("interactions") + void requesterMissingLeaseRequestsAreRejected( + BiFunction> interaction) { + Assertions.assertThat(rSocketRequester.availability()).isCloseTo(0.0, offset(1e-2)); + ByteBuf buffer = byteBufAllocator.buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + Payload payload1 = ByteBufPayload.create(buffer); + StepVerifier.create(interaction.apply(rSocketRequester, payload1)) + .expectError(MissingLeaseException.class) + .verify(Duration.ofSeconds(5)); + + byteBufAllocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("interactions") + void requesterPresentLeaseRequestsAreAccepted( + BiFunction> interaction, FrameType frameType) { + ByteBuf frame = leaseFrame(5_000, 2, Unpooled.EMPTY_BUFFER); + requesterLeaseTracker.handleLeaseFrame(frame); + + Assertions.assertThat(rSocketRequester.availability()).isCloseTo(1.0, offset(1e-2)); + ByteBuf buffer = byteBufAllocator.buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + Payload payload1 = ByteBufPayload.create(buffer); + Flux.from(interaction.apply(rSocketRequester, payload1)) + .as(StepVerifier::create) + .then( + () -> { + if (frameType != REQUEST_FNF) { + connection.addToReceivedBuffer( + PayloadFrameCodec.encodeComplete(byteBufAllocator, 1)); + } + }) + .expectComplete() + .verify(Duration.ofSeconds(5)); + + if (frameType == REQUEST_CHANNEL) { + Assertions.assertThat(connection.getSent()) + .hasSize(2) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == frameType) + .matches(ReferenceCounted::release); + Assertions.assertThat(connection.getSent()) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == COMPLETE) + .matches(ReferenceCounted::release); + } else { + Assertions.assertThat(connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == frameType) + .matches(ReferenceCounted::release); + } + + Assertions.assertThat(rSocketRequester.availability()).isCloseTo(0.5, offset(1e-2)); + + Assertions.assertThat(frame.release()).isTrue(); + + byteBufAllocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("interactions") + @SuppressWarnings({"rawtypes", "unchecked"}) + void requesterDepletedAllowedLeaseRequestsAreRejected( + BiFunction> interaction, FrameType interactionType) { + ByteBuf buffer = byteBufAllocator.buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + Payload payload1 = ByteBufPayload.create(buffer); + ByteBuf leaseFrame = leaseFrame(5_000, 1, Unpooled.EMPTY_BUFFER); + requesterLeaseTracker.handleLeaseFrame(leaseFrame); + + double initialAvailability = requesterLeaseTracker.availability(); + Publisher request = interaction.apply(rSocketRequester, payload1); + + // ensures that lease is not used until the frame is sent + Assertions.assertThat(initialAvailability).isEqualTo(requesterLeaseTracker.availability()); + Assertions.assertThat(connection.getSent()).hasSize(0); + + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + request.subscribe(assertSubscriber); + + // if request is FNF, then request frame is sent on subscribe + // otherwise we need to make request(1) + if (interactionType != REQUEST_FNF) { + Assertions.assertThat(initialAvailability).isEqualTo(requesterLeaseTracker.availability()); + Assertions.assertThat(connection.getSent()).hasSize(0); + + assertSubscriber.request(1); + } + + // ensures availability is changed and lease is used only up on frame sending + Assertions.assertThat(rSocketRequester.availability()).isCloseTo(0.0, offset(1e-2)); + + if (interactionType == REQUEST_CHANNEL) { + Assertions.assertThat(connection.getSent()) + .hasSize(2) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == interactionType) + .matches(ReferenceCounted::release); + Assertions.assertThat(connection.getSent()) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == COMPLETE) + .matches(ReferenceCounted::release); + } else { + Assertions.assertThat(connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == interactionType) + .matches(ReferenceCounted::release); + } + + ByteBuf buffer2 = byteBufAllocator.buffer(); + buffer2.writeCharSequence("test", CharsetUtil.UTF_8); + Payload payload2 = ByteBufPayload.create(buffer2); + Flux.from(interaction.apply(rSocketRequester, payload2)) + .as(StepVerifier::create) + .expectError(MissingLeaseException.class) + .verify(Duration.ofSeconds(5)); + + Assertions.assertThat(leaseFrame.release()).isTrue(); + + byteBufAllocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("interactions") + void requesterExpiredLeaseRequestsAreRejected( + BiFunction> interaction) { + ByteBuf frame = leaseFrame(50, 1, Unpooled.EMPTY_BUFFER); + requesterLeaseTracker.handleLeaseFrame(frame); + + ByteBuf buffer = byteBufAllocator.buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + Payload payload1 = ByteBufPayload.create(buffer); + + Flux.defer(() -> interaction.apply(rSocketRequester, payload1)) + .delaySubscription(Duration.ofMillis(200)) + .as(StepVerifier::create) + .expectError(MissingLeaseException.class) + .verify(Duration.ofSeconds(5)); + + Assertions.assertThat(frame.release()).isTrue(); + + byteBufAllocator.assertHasNoLeaks(); + } + + @Test + void requesterAvailabilityRespectsTransport() { + ByteBuf frame = leaseFrame(5_000, 1, Unpooled.EMPTY_BUFFER); + try { + + requesterLeaseTracker.handleLeaseFrame(frame); + double unavailable = 0.0; + connection.setAvailability(unavailable); + Assertions.assertThat(rSocketRequester.availability()).isCloseTo(unavailable, offset(1e-2)); + } finally { + frame.release(); + } + } + + @ParameterizedTest + @MethodSource("responderInteractions") + void responderMissingLeaseRequestsAreRejected(FrameType frameType) { + ByteBuf buffer = byteBufAllocator.buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + Payload payload1 = ByteBufPayload.create(buffer); + + switch (frameType) { + case REQUEST_FNF: + final ByteBuf fnfFrame = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFrame(fnfFrame); + fnfFrame.release(); + break; + case REQUEST_RESPONSE: + final ByteBuf requestResponseFrame = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFrame(requestResponseFrame); + requestResponseFrame.release(); + break; + case REQUEST_STREAM: + final ByteBuf requestStreamFrame = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, 1, payload1); + rSocketResponder.handleFrame(requestStreamFrame); + requestStreamFrame.release(); + break; + case REQUEST_CHANNEL: + final ByteBuf requestChannelFrame = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, true, 1, payload1); + rSocketResponder.handleFrame(requestChannelFrame); + requestChannelFrame.release(); + break; + } + + if (frameType != REQUEST_FNF) { + Assertions.assertThat(connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == ERROR) + .matches(bb -> Exceptions.from(1, bb) instanceof RejectedException) + .matches(ReferenceCounted::release); + } + + byteBufAllocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("responderInteractions") + void responderPresentLeaseRequestsAreAccepted(FrameType frameType) { + leaseSender.tryEmitNext(Lease.create(Duration.ofMillis(5_000), 2)); + + ByteBuf buffer = byteBufAllocator.buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + Payload payload1 = ByteBufPayload.create(buffer); + + switch (frameType) { + case REQUEST_FNF: + final ByteBuf fnfFrame = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFireAndForget(1, fnfFrame); + fnfFrame.release(); + break; + case REQUEST_RESPONSE: + final ByteBuf requestResponseFrame = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFrame(requestResponseFrame); + requestResponseFrame.release(); + break; + case REQUEST_STREAM: + final ByteBuf requestStreamFrame = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, 1, payload1); + rSocketResponder.handleFrame(requestStreamFrame); + requestStreamFrame.release(); + break; + case REQUEST_CHANNEL: + final ByteBuf requestChannelFrame = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, true, 1, payload1); + rSocketResponder.handleFrame(requestChannelFrame); + requestChannelFrame.release(); + break; + } + + switch (frameType) { + case REQUEST_FNF: + Mockito.verify(mockRSocketHandler).fireAndForget(any()); + break; + case REQUEST_RESPONSE: + Mockito.verify(mockRSocketHandler).requestResponse(any()); + break; + case REQUEST_STREAM: + Mockito.verify(mockRSocketHandler).requestStream(any()); + break; + case REQUEST_CHANNEL: + Mockito.verify(mockRSocketHandler).requestChannel(any()); + break; + } + + Assertions.assertThat(connection.getSent()) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == LEASE) + .matches(ReferenceCounted::release); + + if (frameType != REQUEST_FNF) { + Assertions.assertThat(connection.getSent()) + .hasSize(2) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == COMPLETE) + .matches(ReferenceCounted::release); + } + + byteBufAllocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("responderInteractions") + void responderDepletedAllowedLeaseRequestsAreRejected(FrameType frameType) { + leaseSender.tryEmitNext(Lease.create(Duration.ofMillis(5_000), 1)); + + ByteBuf buffer = byteBufAllocator.buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + Payload payload1 = ByteBufPayload.create(buffer); + + ByteBuf buffer2 = byteBufAllocator.buffer(); + buffer2.writeCharSequence("test2", CharsetUtil.UTF_8); + Payload payload2 = ByteBufPayload.create(buffer2); + + switch (frameType) { + case REQUEST_FNF: + final ByteBuf fnfFrame = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + final ByteBuf fnfFrame2 = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, payload2); + rSocketResponder.handleFrame(fnfFrame); + rSocketResponder.handleFrame(fnfFrame2); + fnfFrame.release(); + fnfFrame2.release(); + break; + case REQUEST_RESPONSE: + final ByteBuf requestResponseFrame = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + final ByteBuf requestResponseFrame2 = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, payload2); + rSocketResponder.handleFrame(requestResponseFrame); + rSocketResponder.handleFrame(requestResponseFrame2); + requestResponseFrame.release(); + requestResponseFrame2.release(); + break; + case REQUEST_STREAM: + final ByteBuf requestStreamFrame = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, 1, payload1); + final ByteBuf requestStreamFrame2 = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, 1, payload2); + rSocketResponder.handleFrame(requestStreamFrame); + rSocketResponder.handleFrame(requestStreamFrame2); + requestStreamFrame.release(); + requestStreamFrame2.release(); + break; + case REQUEST_CHANNEL: + final ByteBuf requestChannelFrame = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, true, 1, payload1); + final ByteBuf requestChannelFrame2 = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, true, 1, payload2); + rSocketResponder.handleFrame(requestChannelFrame); + rSocketResponder.handleFrame(requestChannelFrame2); + requestChannelFrame.release(); + requestChannelFrame2.release(); + break; + } + + switch (frameType) { + case REQUEST_FNF: + Mockito.verify(mockRSocketHandler).fireAndForget(any()); + break; + case REQUEST_RESPONSE: + Mockito.verify(mockRSocketHandler).requestResponse(any()); + break; + case REQUEST_STREAM: + Mockito.verify(mockRSocketHandler).requestStream(any()); + break; + case REQUEST_CHANNEL: + Mockito.verify(mockRSocketHandler).requestChannel(any()); + break; + } + + Assertions.assertThat(connection.getSent()) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == LEASE) + .matches(ReferenceCounted::release); + + if (frameType != REQUEST_FNF) { + Assertions.assertThat(connection.getSent()) + .hasSize(3) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == COMPLETE) + .matches(ReferenceCounted::release); + + Assertions.assertThat(connection.getSent()) + .hasSize(3) + .element(2) + .matches(bb -> FrameHeaderCodec.frameType(bb) == ERROR) + .matches(bb -> Exceptions.from(1, bb) instanceof RejectedException) + .matches(ReferenceCounted::release); + } + + byteBufAllocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("interactions") + void expiredLeaseRequestsAreRejected(BiFunction> interaction) { + leaseSender.tryEmitNext(Lease.create(Duration.ofMillis(50), 1)); + + ByteBuf buffer = byteBufAllocator.buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + Payload payload1 = ByteBufPayload.create(buffer); + + Flux.from(interaction.apply(rSocketRequester, payload1)) + .delaySubscription(Duration.ofMillis(100)) + .as(StepVerifier::create) + .expectError(MissingLeaseException.class) + .verify(Duration.ofSeconds(5)); + + Assertions.assertThat(connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == LEASE) + .matches(ReferenceCounted::release); + + byteBufAllocator.assertHasNoLeaks(); + } + + @Test + void sendLease() { + ByteBuf metadata = byteBufAllocator.buffer(); + Charset utf8 = StandardCharsets.UTF_8; + String metadataContent = "test"; + metadata.writeCharSequence(metadataContent, utf8); + int ttl = 5_000; + int numberOfRequests = 2; + leaseSender.tryEmitNext(Lease.create(Duration.ofMillis(5_000), 2, metadata)); + + ByteBuf leaseFrame = + connection + .getSent() + .stream() + .filter(f -> FrameHeaderCodec.frameType(f) == FrameType.LEASE) + .findFirst() + .orElseThrow(() -> new IllegalStateException("Lease frame not sent")); + + try { + Assertions.assertThat(LeaseFrameCodec.ttl(leaseFrame)).isEqualTo(ttl); + Assertions.assertThat(LeaseFrameCodec.numRequests(leaseFrame)).isEqualTo(numberOfRequests); + Assertions.assertThat(LeaseFrameCodec.metadata(leaseFrame).toString(utf8)) + .isEqualTo(metadataContent); + } finally { + leaseFrame.release(); + } + } + + // @Test + // void receiveLease() { + // Collection receivedLeases = new ArrayList<>(); + // leaseReceiver.subscribe(lease -> receivedLeases.add(lease)); + // + // ByteBuf metadata = byteBufAllocator.buffer(); + // Charset utf8 = StandardCharsets.UTF_8; + // String metadataContent = "test"; + // metadata.writeCharSequence(metadataContent, utf8); + // int ttl = 5_000; + // int numberOfRequests = 2; + // + // ByteBuf leaseFrame = leaseFrame(ttl, numberOfRequests, metadata).retain(1); + // + // connection.addToReceivedBuffer(leaseFrame); + // + // Assertions.assertThat(receivedLeases.isEmpty()).isFalse(); + // Lease receivedLease = receivedLeases.iterator().next(); + // Assertions.assertThat(receivedLease.getTimeToLiveMillis()).isEqualTo(ttl); + // + // Assertions.assertThat(receivedLease.getStartingAllowedRequests()).isEqualTo(numberOfRequests); + // Assertions.assertThat(receivedLease.metadata().toString(utf8)).isEqualTo(metadataContent); + // + // ReferenceCountUtil.safeRelease(leaseFrame); + // } + + ByteBuf leaseFrame(int ttl, int requests, ByteBuf metadata) { + return LeaseFrameCodec.encode(byteBufAllocator, ttl, requests, metadata); + } + + static Stream interactions() { + return Stream.of( + Arguments.of( + (BiFunction>) RSocket::fireAndForget, + FrameType.REQUEST_FNF), + Arguments.of( + (BiFunction>) RSocket::requestResponse, + FrameType.REQUEST_RESPONSE), + Arguments.of( + (BiFunction>) RSocket::requestStream, + FrameType.REQUEST_STREAM), + Arguments.of( + (BiFunction>) + (rSocket, payload) -> rSocket.requestChannel(Mono.just(payload)), + FrameType.REQUEST_CHANNEL)); + } + + static Stream responderInteractions() { + return Stream.of( + FrameType.REQUEST_FNF, + FrameType.REQUEST_RESPONSE, + FrameType.REQUEST_STREAM, + FrameType.REQUEST_CHANNEL); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java new file mode 100644 index 000000000..966fd65f2 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java @@ -0,0 +1,203 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.rsocket.FrameAssert; +import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.transport.ClientTransport; +import java.io.UncheckedIOException; +import java.time.Duration; +import java.util.Iterator; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.function.Consumer; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import reactor.core.Exceptions; +import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; + +public class RSocketReconnectTest { + + private Queue retries = new ConcurrentLinkedQueue<>(); + + @Test + public void shouldBeASharedReconnectableInstanceOfRSocketMono() throws InterruptedException { + TestClientTransport[] testClientTransport = + new TestClientTransport[] {new TestClientTransport()}; + Mono rSocketMono = + RSocketConnector.create() + .reconnect(Retry.indefinitely()) + .connect(() -> testClientTransport[0]); + + RSocket rSocket1 = rSocketMono.block(); + RSocket rSocket2 = rSocketMono.block(); + + FrameAssert.assertThat(testClientTransport[0].testConnection().awaitFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); + + assertThat(rSocket1).isEqualTo(rSocket2); + + testClientTransport[0].testConnection().dispose(); + rSocket1.onClose().block(Duration.ofSeconds(1)); + testClientTransport[0].alloc().assertHasNoLeaks(); + testClientTransport[0] = new TestClientTransport(); + + RSocket rSocket3 = rSocketMono.block(); + RSocket rSocket4 = rSocketMono.block(); + + FrameAssert.assertThat(testClientTransport[0].testConnection().awaitFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); + + assertThat(rSocket3).isEqualTo(rSocket4).isNotEqualTo(rSocket2); + + testClientTransport[0].testConnection().dispose(); + rSocket3.onClose().block(Duration.ofSeconds(1)); + testClientTransport[0].alloc().assertHasNoLeaks(); + } + + @Test + @SuppressWarnings({"rawtype"}) + public void shouldBeRetrieableConnectionSharedReconnectableInstanceOfRSocketMono() { + ClientTransport transport = Mockito.mock(ClientTransport.class); + TestClientTransport transport1 = new TestClientTransport(); + Mockito.when(transport.connect()) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenReturn(transport1.connect()); + Mono rSocketMono = + RSocketConnector.create() + .reconnect( + Retry.backoff(4, Duration.ofMillis(100)) + .maxBackoff(Duration.ofMillis(500)) + .doAfterRetry(onRetry())) + .connect(transport); + + RSocket rSocket1 = rSocketMono.block(); + RSocket rSocket2 = rSocketMono.block(); + + assertThat(rSocket1).isEqualTo(rSocket2); + assertRetries( + UncheckedIOException.class, + UncheckedIOException.class, + UncheckedIOException.class, + UncheckedIOException.class); + + FrameAssert.assertThat(transport1.testConnection().awaitFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); + + transport1.testConnection().dispose(); + rSocket1.onClose().block(Duration.ofSeconds(1)); + transport1.alloc().assertHasNoLeaks(); + } + + @Test + @SuppressWarnings({"rawtype"}) + public void shouldBeExaustedRetrieableConnectionSharedReconnectableInstanceOfRSocketMono() { + ClientTransport transport = Mockito.mock(ClientTransport.class); + TestClientTransport transport1 = new TestClientTransport(); + Mockito.when(transport.connect()) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenReturn(transport1.connect()); + Mono rSocketMono = + RSocketConnector.create() + .reconnect( + Retry.backoff(4, Duration.ofMillis(100)) + .maxBackoff(Duration.ofMillis(500)) + .doAfterRetry(onRetry())) + .connect(transport); + + Assertions.assertThatThrownBy(rSocketMono::block) + .matches(Exceptions::isRetryExhausted) + .hasCauseInstanceOf(UncheckedIOException.class); + + Assertions.assertThatThrownBy(rSocketMono::block) + .matches(Exceptions::isRetryExhausted) + .hasCauseInstanceOf(UncheckedIOException.class); + + assertRetries( + UncheckedIOException.class, + UncheckedIOException.class, + UncheckedIOException.class, + UncheckedIOException.class); + + transport1.alloc().assertHasNoLeaks(); + } + + @Test + public void shouldBeNotBeASharedReconnectableInstanceOfRSocketMono() { + TestClientTransport transport = new TestClientTransport(); + Mono rSocketMono = RSocketConnector.connectWith(transport); + + RSocket rSocket1 = rSocketMono.block(); + TestDuplexConnection connection1 = transport.testConnection(); + + FrameAssert.assertThat(connection1.awaitFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); + + RSocket rSocket2 = rSocketMono.block(); + TestDuplexConnection connection2 = transport.testConnection(); + + assertThat(rSocket1).isNotEqualTo(rSocket2); + + FrameAssert.assertThat(connection2.awaitFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); + + connection1.dispose(); + connection2.dispose(); + rSocket1.onClose().block(Duration.ofSeconds(1)); + rSocket2.onClose().block(Duration.ofSeconds(1)); + transport.alloc().assertHasNoLeaks(); + } + + @SafeVarargs + private final void assertRetries(Class... exceptions) { + assertThat(retries.size()).isEqualTo(exceptions.length); + int index = 0; + for (Iterator it = retries.iterator(); it.hasNext(); ) { + Retry.RetrySignal retryContext = it.next(); + assertThat(retryContext.totalRetries()).isEqualTo(index); + assertThat(retryContext.failure().getClass()).isEqualTo(exceptions[index]); + index++; + } + } + + Consumer onRetry() { + return context -> retries.add(context); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java new file mode 100644 index 000000000..01eb998c7 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java @@ -0,0 +1,206 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.CharsetUtil; +import io.rsocket.FrameAssert; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.DefaultPayload; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.util.RaceTestUtils; + +class RSocketRequesterSubscribersTest { + + private static final Set REQUEST_TYPES = + new HashSet<>( + Arrays.asList( + FrameType.METADATA_PUSH, + FrameType.REQUEST_FNF, + FrameType.REQUEST_RESPONSE, + FrameType.REQUEST_STREAM, + FrameType.REQUEST_CHANNEL)); + + private LeaksTrackingByteBufAllocator allocator; + private RSocket rSocketRequester; + private TestDuplexConnection connection; + protected Sinks.Empty thisClosedSink; + protected Sinks.Empty otherClosedSink; + + @AfterEach + void tearDownAndCheckNoLeaks() { + allocator.assertHasNoLeaks(); + } + + @BeforeEach + void setUp() { + allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + connection = new TestDuplexConnection(allocator); + this.thisClosedSink = Sinks.empty(); + this.otherClosedSink = Sinks.empty(); + rSocketRequester = + new RSocketRequester( + connection, + PayloadDecoder.DEFAULT, + StreamIdSupplier.clientSupplier(), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + 0, + 0, + null, + __ -> null, + null, + thisClosedSink, + otherClosedSink.asMono().and(thisClosedSink.asMono())); + } + + @ParameterizedTest + @MethodSource("allInteractions") + @SuppressWarnings({"rawtypes", "unchecked"}) + void singleSubscriber(Function> interaction, FrameType requestType) { + Flux response = Flux.from(interaction.apply(rSocketRequester)); + + AssertSubscriber assertSubscriberA = AssertSubscriber.create(); + AssertSubscriber assertSubscriberB = AssertSubscriber.create(); + + response.subscribe(assertSubscriberA); + response.subscribe(assertSubscriberB); + + if (requestType != FrameType.REQUEST_FNF && requestType != FrameType.METADATA_PUSH) { + connection.addToReceivedBuffer(PayloadFrameCodec.encodeComplete(connection.alloc(), 1)); + } + + assertSubscriberA.assertTerminated(); + assertSubscriberB.assertTerminated(); + + FrameAssert.assertThat(connection.pollFrame()).typeOf(requestType).hasNoLeaks(); + + if (requestType == FrameType.REQUEST_CHANNEL) { + FrameAssert.assertThat(connection.pollFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + } + } + + @ParameterizedTest + @MethodSource("allInteractions") + void singleSubscriberInCaseOfRacing( + Function> interaction, FrameType requestType) { + for (int i = 1; i < 20000; i += 2) { + Flux response = Flux.from(interaction.apply(rSocketRequester)); + AssertSubscriber assertSubscriberA = AssertSubscriber.create(); + AssertSubscriber assertSubscriberB = AssertSubscriber.create(); + + RaceTestUtils.race( + () -> response.subscribe(assertSubscriberA), () -> response.subscribe(assertSubscriberB)); + + if (requestType != FrameType.REQUEST_FNF && requestType != FrameType.METADATA_PUSH) { + connection.addToReceivedBuffer(PayloadFrameCodec.encodeComplete(connection.alloc(), i)); + } + + assertSubscriberA.assertTerminated(); + assertSubscriberB.assertTerminated(); + + Assertions.assertThat(new AssertSubscriber[] {assertSubscriberA, assertSubscriberB}) + .anySatisfy(as -> as.assertError(IllegalStateException.class)); + + if (requestType == FrameType.REQUEST_CHANNEL) { + Assertions.assertThat(connection.getSent()) + .hasSize(2) + .first() + .matches(bb -> REQUEST_TYPES.contains(FrameHeaderCodec.frameType(bb))) + .matches(ByteBuf::release); + Assertions.assertThat(connection.getSent()) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == FrameType.COMPLETE) + .matches(ByteBuf::release); + } else { + Assertions.assertThat(connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> REQUEST_TYPES.contains(FrameHeaderCodec.frameType(bb))) + .matches(ByteBuf::release); + } + connection.clearSendReceiveBuffers(); + } + } + + @ParameterizedTest + @MethodSource("allInteractions") + void singleSubscriberInteractionsAreLazy(Function> interaction) { + Flux response = Flux.from(interaction.apply(rSocketRequester)); + + Assertions.assertThat(connection.getSent().size()).isEqualTo(0); + } + + static long requestFramesCount(Collection frames) { + return frames + .stream() + .filter(frame -> REQUEST_TYPES.contains(FrameHeaderCodec.frameType(frame))) + .count(); + } + + static Stream allInteractions() { + return Stream.of( + Arguments.of( + (Function>) + rSocket -> rSocket.fireAndForget(DefaultPayload.create("test")), + FrameType.REQUEST_FNF), + Arguments.of( + (Function>) + rSocket -> rSocket.requestResponse(DefaultPayload.create("test")), + FrameType.REQUEST_RESPONSE), + Arguments.of( + (Function>) + rSocket -> rSocket.requestStream(DefaultPayload.create("test")), + FrameType.REQUEST_STREAM), + Arguments.of( + (Function>) + rSocket -> rSocket.requestChannel(Mono.just(DefaultPayload.create("test"))), + FrameType.REQUEST_CHANNEL), + Arguments.of( + (Function>) + rSocket -> + rSocket.metadataPush( + DefaultPayload.create(new byte[0], "test".getBytes(CharsetUtil.UTF_8))), + FrameType.METADATA_PUSH)); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java new file mode 100644 index 000000000..5cfa76a1c --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java @@ -0,0 +1,113 @@ +package io.rsocket.core; + +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketRequesterTest.ClientSocketRule; +import io.rsocket.frame.FrameType; +import io.rsocket.util.EmptyPayload; +import java.nio.channels.ClosedChannelException; +import java.time.Duration; +import java.util.Arrays; +import java.util.function.Function; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +public class RSocketRequesterTerminationTest { + + public final ClientSocketRule rule = new ClientSocketRule(); + + @BeforeEach + public void setup() { + rule.init(); + } + + @AfterEach + public void tearDownAndCheckNoLeaks() { + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("rsocketInteractions") + public void testCurrentStreamIsTerminatedOnConnectionClose( + FrameType requestType, Function> interaction) { + RSocketRequester rSocket = rule.socket; + + StepVerifier.create(interaction.apply(rSocket)) + .then( + () -> { + FrameAssert.assertThat(rule.connection.pollFrame()).typeOf(requestType).hasNoLeaks(); + }) + .then(() -> rule.connection.dispose()) + .expectError(ClosedChannelException.class) + .verify(Duration.ofSeconds(5)); + } + + @ParameterizedTest + @MethodSource("rsocketInteractions") + public void testSubsequentStreamIsTerminatedAfterConnectionClose( + FrameType requestType, Function> interaction) { + RSocketRequester rSocket = rule.socket; + + rule.connection.dispose(); + StepVerifier.create(interaction.apply(rSocket)) + .expectError(ClosedChannelException.class) + .verify(Duration.ofSeconds(5)); + } + + public static Iterable rsocketInteractions() { + EmptyPayload payload = EmptyPayload.INSTANCE; + + Arguments resp = + Arguments.of( + FrameType.REQUEST_RESPONSE, + new Function>() { + @Override + public Mono apply(RSocket rSocket) { + return rSocket.requestResponse(payload); + } + + @Override + public String toString() { + return "Request Response"; + } + }); + Arguments stream = + Arguments.of( + FrameType.REQUEST_STREAM, + new Function>() { + @Override + public Flux apply(RSocket rSocket) { + return rSocket.requestStream(payload); + } + + @Override + public String toString() { + return "Request Stream"; + } + }); + Arguments channel = + Arguments.of( + FrameType.REQUEST_CHANNEL, + new Function>() { + @Override + public Flux apply(RSocket rSocket) { + return rSocket.requestChannel(Flux.never().startWith(payload)); + } + + @Override + public String toString() { + return "Request Channel"; + } + }); + + return Arrays.asList(resp, stream, channel); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java new file mode 100644 index 000000000..a1199f698 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -0,0 +1,1516 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.ReassemblyUtils.ILLEGAL_REASSEMBLED_PAYLOAD_SIZE; +import static io.rsocket.core.TestRequesterResponderSupport.fixedSizePayload; +import static io.rsocket.core.TestRequesterResponderSupport.genericPayload; +import static io.rsocket.core.TestRequesterResponderSupport.prepareFragments; +import static io.rsocket.core.TestRequesterResponderSupport.randomMetadataOnlyPayload; +import static io.rsocket.core.TestRequesterResponderSupport.randomPayload; +import static io.rsocket.frame.FrameHeaderCodec.frameType; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameType.CANCEL; +import static io.rsocket.frame.FrameType.COMPLETE; +import static io.rsocket.frame.FrameType.METADATA_PUSH; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_FNF; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.RSocket; +import io.rsocket.RaceTestConstants; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.exceptions.CustomRSocketException; +import io.rsocket.exceptions.RejectedSetupException; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.TestSubscriber; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; +import java.nio.channels.ClosedChannelException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Stream; +import org.assertj.core.api.Assumptions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.Scannable; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; + +public class RSocketRequesterTest { + + ClientSocketRule rule; + + @BeforeEach + public void setUp() throws Throwable { + Hooks.onNextDropped(ReferenceCountUtil::safeRelease); + Hooks.onErrorDropped((t) -> {}); + rule = new ClientSocketRule(); + rule.init(); + } + + @AfterEach + public void tearDown() { + Hooks.resetOnErrorDropped(); + Hooks.resetOnNextDropped(); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testInvalidFrameOnStream0ShouldNotTerminateRSocket() { + rule.connection.addToReceivedBuffer(RequestNFrameCodec.encode(rule.alloc(), 0, 10)); + assertThat(rule.socket.isDisposed()).isFalse(); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testStreamInitialN() { + Flux stream = rule.socket.requestStream(EmptyPayload.INSTANCE); + + BaseSubscriber subscriber = + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) { + // don't request here + } + }; + stream.subscribe(subscriber); + + assertThat(rule.connection.getSent()).isEmpty(); + + subscriber.request(5); + + List sent = new ArrayList<>(rule.connection.getSent()); + + assertThat(sent.size()).describedAs("sent frame count").isEqualTo(1); + + ByteBuf f = sent.get(0); + + assertThat(frameType(f)).describedAs("initial frame").isEqualTo(REQUEST_STREAM); + assertThat(RequestStreamFrameCodec.initialRequestN(f)) + .describedAs("initial request n") + .isEqualTo(5L); + assertThat(f.release()).describedAs("should be released").isEqualTo(true); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testHandleSetupException() { + rule.connection.addToReceivedBuffer( + ErrorFrameCodec.encode(rule.alloc(), 0, new RejectedSetupException("boom"))); + assertThatThrownBy(() -> rule.socket.onClose().block()) + .isInstanceOf(RejectedSetupException.class); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testHandleApplicationException() { + rule.connection.clearSendReceiveBuffers(); + Publisher response = rule.socket.requestResponse(EmptyPayload.INSTANCE); + Subscriber responseSub = TestSubscriber.create(); + response.subscribe(responseSub); + + int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); + rule.connection.addToReceivedBuffer( + ErrorFrameCodec.encode(rule.alloc(), streamId, new ApplicationErrorException("error"))); + + verify(responseSub).onError(any(ApplicationErrorException.class)); + + assertThat(rule.connection.getSent()) + // requestResponseFrame + .hasSize(1) + .allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testHandleValidFrame() { + Publisher response = rule.socket.requestResponse(EmptyPayload.INSTANCE); + Subscriber sub = TestSubscriber.create(); + response.subscribe(sub); + + int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); + rule.connection.addToReceivedBuffer( + PayloadFrameCodec.encodeNextReleasingPayload( + rule.alloc(), streamId, EmptyPayload.INSTANCE)); + + verify(sub).onComplete(); + assertThat(rule.connection.getSent()).hasSize(1).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testRequestReplyWithCancel() { + Mono response = rule.socket.requestResponse(EmptyPayload.INSTANCE); + + try { + response.block(Duration.ofMillis(100)); + } catch (IllegalStateException ise) { + } + + List sent = new ArrayList<>(rule.connection.getSent()); + + assertThat(frameType(sent.get(0))) + .describedAs("Unexpected frame sent on the connection.") + .isEqualTo(REQUEST_RESPONSE); + assertThat(frameType(sent.get(1))) + .describedAs("Unexpected frame sent on the connection.") + .isEqualTo(CANCEL); + assertThat(sent).hasSize(2).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @Test + @Disabled("invalid") + @Timeout(2_000) + public void testRequestReplyErrorOnSend() { + rule.connection.setAvailability(0); // Fails send + Mono response = rule.socket.requestResponse(EmptyPayload.INSTANCE); + Subscriber responseSub = TestSubscriber.create(10); + response.subscribe(responseSub); + + this.rule + .socket + .onClose() + .as(StepVerifier::create) + .expectComplete() + .verify(Duration.ofMillis(100)); + + verify(responseSub).onSubscribe(any(Subscription.class)); + + rule.assertHasNoLeaks(); + // TODO this should get the error reported through the response subscription + // verify(responseSub).onError(any(RuntimeException.class)); + } + + @Test + @Timeout(2_000) + public void testChannelRequestCancellation() { + Sinks.Empty cancelled = Sinks.empty(); + Flux request = Flux.never().doOnCancel(cancelled::tryEmitEmpty); + rule.socket.requestChannel(request).subscribe().dispose(); + Flux.firstWithSignal( + cancelled.asMono(), + Flux.error(new IllegalStateException("Channel request not cancelled")) + .delaySubscription(Duration.ofSeconds(1))) + .blockFirst(); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testChannelRequestCancellation2() { + Sinks.Empty cancelled = Sinks.empty(); + Flux request = + Flux.just(EmptyPayload.INSTANCE).repeat(259).doOnCancel(cancelled::tryEmitEmpty); + rule.socket.requestChannel(request).subscribe().dispose(); + Flux.firstWithSignal( + cancelled.asMono(), + Flux.error(new IllegalStateException("Channel request not cancelled")) + .delaySubscription(Duration.ofSeconds(1))) + .blockFirst(); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @Test + public void testChannelRequestServerSideCancellation() { + Sinks.One cancelled = Sinks.one(); + Sinks.Many request = Sinks.many().unicast().onBackpressureBuffer(); + request.tryEmitNext(EmptyPayload.INSTANCE); + rule.socket + .requestChannel(request.asFlux()) + .subscribe(cancelled::tryEmitValue, cancelled::tryEmitError, cancelled::tryEmitEmpty); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + rule.connection.addToReceivedBuffer(CancelFrameCodec.encode(rule.alloc(), streamId)); + rule.connection.addToReceivedBuffer(PayloadFrameCodec.encodeComplete(rule.alloc(), streamId)); + Flux.firstWithSignal( + cancelled.asMono(), + Flux.error(new IllegalStateException("Channel request not cancelled")) + .delaySubscription(Duration.ofSeconds(1))) + .blockFirst(); + + assertThat(request.scan(Scannable.Attr.TERMINATED) || request.scan(Scannable.Attr.CANCELLED)) + .isTrue(); + assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> frameType(bb) == REQUEST_CHANNEL) + .matches(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @Test + public void testCorrectFrameOrder() { + Sinks.One delayer = Sinks.one(); + BaseSubscriber subscriber = + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) {} + }; + rule.socket + .requestChannel( + Flux.concat(Flux.just(0).delayUntil(i -> delayer.asMono()), Flux.range(1, 999)) + .map(i -> DefaultPayload.create(i + ""))) + .subscribe(subscriber); + + subscriber.request(1); + subscriber.request(Long.MAX_VALUE); + delayer.tryEmitEmpty(); + + Iterator iterator = rule.connection.getSent().iterator(); + + ByteBuf initialFrame = iterator.next(); + + assertThat(FrameHeaderCodec.frameType(initialFrame)).isEqualTo(REQUEST_CHANNEL); + assertThat(RequestChannelFrameCodec.initialRequestN(initialFrame)).isEqualTo(Long.MAX_VALUE); + assertThat(RequestChannelFrameCodec.data(initialFrame).toString(CharsetUtil.UTF_8)) + .isEqualTo("0"); + assertThat(initialFrame.release()).isTrue(); + + assertThat(iterator.hasNext()).isFalse(); + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @ValueSource(ints = {128, 256, FRAME_LENGTH_MASK}) + public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentation( + int maxFrameLength) { + rule.setMaxFrameLength(maxFrameLength); + prepareCalls() + .forEach( + generator -> { + byte[] metadata = new byte[maxFrameLength]; + byte[] data = new byte[maxFrameLength]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + StepVerifier.create( + generator.apply(rule.socket, DefaultPayload.create(data, metadata))) + .expectSubscription() + .expectErrorSatisfies( + t -> + assertThat(t) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, maxFrameLength))) + .verify(); + rule.assertHasNoLeaks(); + }); + } + + @ParameterizedTest + @ValueSource(ints = {128, 256, FRAME_LENGTH_MASK}) + public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentation1( + int maxFrameLength) { + rule.setMaxFrameLength(maxFrameLength); + prepareCalls() + .forEach( + generator -> { + byte[] metadata = new byte[maxFrameLength]; + byte[] data = new byte[maxFrameLength]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + assertThatThrownBy( + () -> { + final Publisher source = + generator.apply(rule.socket, DefaultPayload.create(data, metadata)); + + if (source instanceof Mono) { + ((Mono) source).block(); + } else { + ((Flux) source).blockLast(); + } + }) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, maxFrameLength)); + + rule.assertHasNoLeaks(); + }); + } + + @Test + public void shouldRejectCallOfNoMetadataPayload() { + final ByteBuf data = rule.allocator.buffer(10); + final Payload payload = ByteBufPayload.create(data); + StepVerifier.create(rule.socket.metadataPush(payload)) + .expectSubscription() + .expectErrorSatisfies( + t -> + assertThat(t) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Metadata push should have metadata field present")) + .verify(); + PayloadAssert.assertThat(payload).isReleased(); + rule.assertHasNoLeaks(); + } + + @Test + public void shouldRejectCallOfNoMetadataPayloadBlocking() { + final ByteBuf data = rule.allocator.buffer(10); + final Payload payload = ByteBufPayload.create(data); + + assertThatThrownBy(() -> rule.socket.metadataPush(payload).block()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Metadata push should have metadata field present"); + PayloadAssert.assertThat(payload).isReleased(); + rule.assertHasNoLeaks(); + } + + static Stream>> prepareCalls() { + return Stream.of( + RSocket::fireAndForget, + RSocket::requestResponse, + RSocket::requestStream, + (rSocket, payload) -> rSocket.requestChannel(Flux.just(payload)), + RSocket::metadataPush); + } + + @ParameterizedTest + @ValueSource(ints = {128, 256, FrameLengthCodec.FRAME_LENGTH_MASK}) + public void + shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentationForRequestChannelCase( + int maxFrameLength) { + rule.setMaxFrameLength(maxFrameLength); + byte[] metadata = new byte[maxFrameLength]; + byte[] data = new byte[maxFrameLength]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + StepVerifier.create( + rule.socket.requestChannel( + Flux.just(EmptyPayload.INSTANCE, DefaultPayload.create(data, metadata))), + 0) + .expectSubscription() + .thenRequest(2) + .then( + () -> { + rule.connection.addToReceivedBuffer( + RequestNFrameCodec.encode( + rule.alloc(), rule.getStreamIdForRequestType(REQUEST_CHANNEL), 2)); + }) + .expectErrorSatisfies( + t -> + assertThat(t) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, maxFrameLength))) + .verify(); + assertThat(rule.connection.getSent()) + // expect to be sent RequestChannelFrame + // expect to be sent CancelFrame + .hasSize(2) + .allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("racingCases") + public void checkNoLeaksOnRacing( + Function> initiator, + BiConsumer, ClientSocketRule> runner) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + ClientSocketRule clientSocketRule = new ClientSocketRule(); + + clientSocketRule.init(); + + Publisher payloadP = initiator.apply(clientSocketRule); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + if (payloadP instanceof Flux) { + ((Flux) payloadP).doOnNext(Payload::release).subscribe(assertSubscriber); + } else { + ((Mono) payloadP).doOnNext(Payload::release).subscribe(assertSubscriber); + } + + runner.accept(assertSubscriber, clientSocketRule); + + assertThat(clientSocketRule.connection.getSent()).allMatch(ReferenceCounted::release); + + clientSocketRule.assertHasNoLeaks(); + } + } + + private static Stream racingCases() { + return Stream.of( + Arguments.of( + (Function>) + (rule) -> rule.socket.requestStream(EmptyPayload.INSTANCE), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + as.request(1); + int streamId = rule.getStreamIdForRequestType(REQUEST_STREAM); + ByteBuf frame = + PayloadFrameCodec.encode( + allocator, streamId, false, false, true, metadata, data); + + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> rule.socket.requestChannel(Flux.just(EmptyPayload.INSTANCE)), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + as.request(1); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + ByteBuf frame = + PayloadFrameCodec.encode( + allocator, streamId, false, false, true, metadata, data); + + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("metadata", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("data", CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); + + return rule.socket.requestStream(payload); + }, + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + RaceTestUtils.race(() -> as.request(1), as::cancel); + // ensures proper frames order + if (rule.connection.getSent().size() > 0) { + assertThat(rule.connection.getSent()).hasSize(2); + assertThat(rule.connection.getSent()) + .element(0) + .matches( + bb -> frameType(bb) == REQUEST_STREAM, + "Expected first frame matches {" + + REQUEST_STREAM + + "} but was {" + + frameType(rule.connection.getSent().stream().findFirst().get()) + + "}"); + assertThat(rule.connection.getSent()) + .element(1) + .matches( + bb -> frameType(bb) == CANCEL, + "Expected first frame matches {" + + CANCEL + + "} but was {" + + frameType( + rule.connection.getSent().stream().skip(1).findFirst().get()) + + "}"); + } + }), + Arguments.of( + (Function>) + (rule) -> { + ByteBufAllocator allocator = rule.alloc(); + return rule.socket.requestChannel( + Flux.generate( + () -> 1L, + (index, sink) -> { + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("metadata", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("data", CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); + sink.next(payload); + sink.complete(); + return ++index; + })); + }, + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + RaceTestUtils.race(() -> as.request(1), as::cancel); + // ensures proper frames order + int size = rule.connection.getSent().size(); + if (size > 0) { + + assertThat(size).isLessThanOrEqualTo(3).isGreaterThanOrEqualTo(2); + assertThat(rule.connection.getSent()) + .element(0) + .matches( + bb -> frameType(bb) == REQUEST_CHANNEL, + "Expected first frame matches {" + + REQUEST_CHANNEL + + "} but was {" + + frameType(rule.connection.getSent().stream().findFirst().get()) + + "}"); + if (size == 2) { + assertThat(rule.connection.getSent()) + .element(1) + .matches( + bb -> frameType(bb) == CANCEL, + "Expected second frame matches {" + + CANCEL + + "} but was {" + + frameType( + rule.connection.getSent().stream().skip(1).findFirst().get()) + + "}"); + } else { + assertThat(rule.connection.getSent()) + .element(1) + .matches( + bb -> frameType(bb) == COMPLETE || frameType(bb) == CANCEL, + "Expected second frame matches {" + + COMPLETE + + " or " + + CANCEL + + "} but was {" + + frameType( + rule.connection.getSent().stream().skip(1).findFirst().get()) + + "}"); + assertThat(rule.connection.getSent()) + .element(2) + .matches( + bb -> frameType(bb) == CANCEL || frameType(bb) == COMPLETE, + "Expected third frame matches {" + + COMPLETE + + " or " + + CANCEL + + "} but was {" + + frameType( + rule.connection.getSent().stream().skip(2).findFirst().get()) + + "}"); + } + } + }), + Arguments.of( + (Function>) + (rule) -> + rule.socket.requestChannel( + Flux.generate( + () -> 1L, + (index, sink) -> { + ByteBuf data = rule.alloc().buffer(); + data.writeCharSequence("d" + index, CharsetUtil.UTF_8); + ByteBuf metadata = rule.alloc().buffer(); + metadata.writeCharSequence("m" + index, CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); + sink.next(payload); + return ++index; + })), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + as.request(1); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + ByteBuf frame = CancelFrameCodec.encode(allocator, streamId); + + RaceTestUtils.race( + () -> as.request(Long.MAX_VALUE), + () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> + rule.socket.requestChannel( + Flux.generate( + () -> 1L, + (index, sink) -> { + ByteBuf data = rule.alloc().buffer(); + data.writeCharSequence("d" + index, CharsetUtil.UTF_8); + ByteBuf metadata = rule.alloc().buffer(); + metadata.writeCharSequence("m" + index, CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); + sink.next(payload); + return ++index; + })), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + as.request(1); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + ByteBuf frame = + ErrorFrameCodec.encode(allocator, streamId, new RuntimeException("test")); + + RaceTestUtils.race( + () -> as.request(Long.MAX_VALUE), + () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> { + ByteBuf data = rule.allocator.buffer(); + data.writeCharSequence("testData", CharsetUtil.UTF_8); + + ByteBuf metadata = rule.allocator.buffer(); + metadata.writeCharSequence("testMetadata", CharsetUtil.UTF_8); + Payload requestPayload = ByteBufPayload.create(data, metadata); + return rule.socket.requestResponse(requestPayload); + }, + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + as.request(Long.MAX_VALUE); + int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); + ByteBuf frame = + PayloadFrameCodec.encode( + allocator, streamId, false, false, true, metadata, data); + + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> { + ByteBuf data = rule.allocator.buffer(); + data.writeCharSequence("testData", CharsetUtil.UTF_8); + + ByteBuf metadata = rule.allocator.buffer(); + metadata.writeCharSequence("testMetadata", CharsetUtil.UTF_8); + Payload requestPayload = ByteBufPayload.create(data, metadata); + return rule.socket.requestStream(requestPayload); + }, + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + as.request(Long.MAX_VALUE); + int streamId = rule.getStreamIdForRequestType(REQUEST_STREAM); + ByteBuf frame = + PayloadFrameCodec.encode( + allocator, streamId, false, true, true, metadata, data); + + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); + })); + } + + @Test + public void simpleOnDiscardRequestChannelTest() { + AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + Sinks.Many testPublisher = Sinks.many().unicast().onBackpressureBuffer(); + + Flux payloadFlux = rule.socket.requestChannel(testPublisher.asFlux()); + + payloadFlux.subscribe(assertSubscriber); + + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d"), ByteBufUtil.writeUtf8(rule.alloc(), "m"))); + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d1"), ByteBufUtil.writeUtf8(rule.alloc(), "m1"))); + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d2"), ByteBufUtil.writeUtf8(rule.alloc(), "m2"))); + + assertSubscriber.cancel(); + + assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); + + rule.assertHasNoLeaks(); + } + + @Test + public void simpleOnDiscardRequestChannelTest2() { + ByteBufAllocator allocator = rule.alloc(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + Sinks.Many testPublisher = Sinks.many().unicast().onBackpressureBuffer(); + + Flux payloadFlux = rule.socket.requestChannel(testPublisher.asFlux()); + + payloadFlux.subscribe(assertSubscriber); + + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d"), ByteBufUtil.writeUtf8(rule.alloc(), "m"))); + + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d1"), ByteBufUtil.writeUtf8(rule.alloc(), "m1"))); + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d2"), ByteBufUtil.writeUtf8(rule.alloc(), "m2"))); + + rule.connection.addToReceivedBuffer( + ErrorFrameCodec.encode( + allocator, streamId, new CustomRSocketException(0x00000404, "test"))); + + assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); + + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("encodeDecodePayloadCases") + public void verifiesThatFrameWithNoMetadataHasDecodedCorrectlyIntoPayload( + FrameType frameType, int framesCnt, int responsesCnt) { + ByteBufAllocator allocator = rule.alloc(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(responsesCnt); + TestPublisher testPublisher = TestPublisher.create(); + + Publisher response; + + switch (frameType) { + case REQUEST_FNF: + response = + testPublisher.mono().flatMap(p -> rule.socket.fireAndForget(p)).then(Mono.empty()); + break; + case REQUEST_RESPONSE: + response = testPublisher.mono().flatMap(p -> rule.socket.requestResponse(p)); + break; + case REQUEST_STREAM: + response = testPublisher.mono().flatMapMany(p -> rule.socket.requestStream(p)); + break; + case REQUEST_CHANNEL: + response = rule.socket.requestChannel(testPublisher.flux()); + break; + default: + throw new UnsupportedOperationException("illegal case"); + } + + response.subscribe(assertSubscriber); + testPublisher.next(ByteBufPayload.create(ByteBufUtil.writeUtf8(rule.alloc(), "d"))); + + int streamId = rule.getStreamIdForRequestType(frameType); + + if (responsesCnt > 0) { + for (int i = 0; i < responsesCnt - 1; i++) { + rule.connection.addToReceivedBuffer( + PayloadFrameCodec.encode( + allocator, + streamId, + false, + false, + true, + null, + Unpooled.wrappedBuffer(("rd" + (i + 1)).getBytes()))); + } + + rule.connection.addToReceivedBuffer( + PayloadFrameCodec.encode( + allocator, + streamId, + false, + true, + true, + null, + Unpooled.wrappedBuffer(("rd" + responsesCnt).getBytes()))); + } + + if (framesCnt > 1) { + rule.connection.addToReceivedBuffer( + RequestNFrameCodec.encode(allocator, streamId, framesCnt)); + } + + for (int i = 1; i < framesCnt; i++) { + testPublisher.next(ByteBufPayload.create(ByteBufUtil.writeUtf8(rule.alloc(), "d" + i))); + } + + assertThat(rule.connection.getSent()) + .describedAs( + "Interaction Type :[%s]. Expected to observe %s frames sent", frameType, framesCnt) + .hasSize(framesCnt) + .allMatch(bb -> !FrameHeaderCodec.hasMetadata(bb)) + .allMatch(ByteBuf::release); + + assertThat(assertSubscriber.isTerminated()) + .describedAs("Interaction Type :[%s]. Expected to be terminated", frameType) + .isTrue(); + + assertThat(assertSubscriber.values()) + .describedAs( + "Interaction Type :[%s]. Expected to observe %s frames received", + frameType, responsesCnt) + .hasSize(responsesCnt) + .allMatch(p -> !p.hasMetadata()) + .allMatch(p -> p.release()); + + rule.assertHasNoLeaks(); + rule.connection.clearSendReceiveBuffers(); + } + + static Stream encodeDecodePayloadCases() { + return Stream.of( + Arguments.of(REQUEST_FNF, 1, 0), + Arguments.of(REQUEST_RESPONSE, 1, 1), + Arguments.of(REQUEST_STREAM, 1, 5), + Arguments.of(REQUEST_CHANNEL, 5, 5)); + } + + @ParameterizedTest + @MethodSource("refCntCases") + public void ensureSendsErrorOnIllegalRefCntPayload( + BiFunction> sourceProducer) { + Payload invalidPayload = + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "test"), + ByteBufUtil.writeUtf8(rule.alloc(), "test")); + invalidPayload.release(); + + Publisher source = sourceProducer.apply(invalidPayload, rule); + + StepVerifier.create(source, 1) + .expectError(IllegalReferenceCountException.class) + .verify(Duration.ofMillis(1000)); + } + + private static Stream>> refCntCases() { + return Stream.of( + (p, clientSocketRule) -> clientSocketRule.socket.fireAndForget(p), + (p, clientSocketRule) -> clientSocketRule.socket.requestResponse(p), + (p, clientSocketRule) -> clientSocketRule.socket.requestStream(p), + (p, clientSocketRule) -> clientSocketRule.socket.requestChannel(Mono.just(p)), + (p, clientSocketRule) -> { + Flux.from(clientSocketRule.connection.getSentAsPublisher()) + .filter(bb -> frameType(bb) == REQUEST_CHANNEL) + .doOnDiscard(ByteBuf.class, ReferenceCounted::release) + .subscribe( + bb -> { + clientSocketRule.connection.addToReceivedBuffer( + RequestNFrameCodec.encode( + clientSocketRule.allocator, FrameHeaderCodec.streamId(bb), 1)); + bb.release(); + }); + + return clientSocketRule.socket.requestChannel(Flux.just(EmptyPayload.INSTANCE, p)); + }); + } + + @Test + public void ensuresThatNoOpsMustHappenUntilSubscriptionInCaseOfFnfCall() { + Payload payload1 = ByteBufPayload.create("abc1"); + Mono fnf1 = rule.socket.fireAndForget(payload1); + + Payload payload2 = ByteBufPayload.create("abc2"); + Mono fnf2 = rule.socket.fireAndForget(payload2); + + assertThat(rule.connection.getSent()).isEmpty(); + + // checks that fnf2 should have id 1 even though it was generated later than fnf1 + AssertSubscriber voidAssertSubscriber2 = fnf2.subscribeWith(AssertSubscriber.create(0)); + voidAssertSubscriber2.assertTerminated().assertNoError(); + assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> frameType(bb) == REQUEST_FNF) + .matches(bb -> FrameHeaderCodec.streamId(bb) == 1) + // ensures that this is fnf1 with abc2 data + .matches( + bb -> + ByteBufUtil.equals( + RequestFireAndForgetFrameCodec.data(bb), + Unpooled.wrappedBuffer("abc2".getBytes()))) + .matches(ReferenceCounted::release); + + rule.connection.clearSendReceiveBuffers(); + + // checks that fnf1 should have id 3 even though it was generated earlier + AssertSubscriber voidAssertSubscriber1 = fnf1.subscribeWith(AssertSubscriber.create(0)); + voidAssertSubscriber1.assertTerminated().assertNoError(); + assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> frameType(bb) == REQUEST_FNF) + .matches(bb -> FrameHeaderCodec.streamId(bb) == 3) + // ensures that this is fnf1 with abc1 data + .matches( + bb -> + ByteBufUtil.equals( + RequestFireAndForgetFrameCodec.data(bb), + Unpooled.wrappedBuffer("abc1".getBytes()))) + .matches(ReferenceCounted::release); + } + + @ParameterizedTest + @MethodSource("requestNInteractions") + public void ensuresThatNoOpsMustHappenUntilFirstRequestN( + FrameType frameType, BiFunction> interaction) { + Payload payload1 = ByteBufPayload.create("abc1"); + Publisher interaction1 = interaction.apply(rule, payload1); + + Payload payload2 = ByteBufPayload.create("abc2"); + Publisher interaction2 = interaction.apply(rule, payload2); + + assertThat(rule.connection.getSent()).isEmpty(); + + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(0); + interaction1.subscribe(assertSubscriber1); + AssertSubscriber assertSubscriber2 = AssertSubscriber.create(0); + interaction2.subscribe(assertSubscriber2); + assertSubscriber1.assertNotTerminated().assertNoError(); + assertSubscriber2.assertNotTerminated().assertNoError(); + // even though we subscribed, nothing should happen until the first requestN + assertThat(rule.connection.getSent()).isEmpty(); + + // first request on the second interaction to ensure that stream id issuing on the first request + assertSubscriber2.request(1); + + assertThat(rule.connection.getSent()) + .hasSize(frameType == REQUEST_CHANNEL ? 2 : 1) + .first() + .matches(bb -> frameType(bb) == frameType) + .matches( + bb -> FrameHeaderCodec.streamId(bb) == 1, + "Expected to have stream ID {1} but got {" + + FrameHeaderCodec.streamId(rule.connection.getSent().iterator().next()) + + "}") + .matches( + bb -> { + switch (frameType) { + case REQUEST_RESPONSE: + return ByteBufUtil.equals( + RequestResponseFrameCodec.data(bb), + Unpooled.wrappedBuffer("abc2".getBytes())); + case REQUEST_STREAM: + return ByteBufUtil.equals( + RequestStreamFrameCodec.data(bb), Unpooled.wrappedBuffer("abc2".getBytes())); + case REQUEST_CHANNEL: + return ByteBufUtil.equals( + RequestChannelFrameCodec.data(bb), Unpooled.wrappedBuffer("abc2".getBytes())); + } + + return false; + }) + .matches(ReferenceCounted::release); + + if (frameType == REQUEST_CHANNEL) { + assertThat(rule.connection.getSent()) + .element(1) + .matches(bb -> frameType(bb) == COMPLETE) + .matches( + bb -> FrameHeaderCodec.streamId(bb) == 1, + "Expected to have stream ID {1} but got {" + + FrameHeaderCodec.streamId(new ArrayList<>(rule.connection.getSent()).get(1)) + + "}") + .matches(ReferenceCounted::release); + } + + rule.connection.clearSendReceiveBuffers(); + + assertSubscriber1.request(1); + assertThat(rule.connection.getSent()) + .hasSize(frameType == REQUEST_CHANNEL ? 2 : 1) + .first() + .matches(bb -> frameType(bb) == frameType) + .matches( + bb -> FrameHeaderCodec.streamId(bb) == 3, + "Expected to have stream ID {1} but got {" + + FrameHeaderCodec.streamId(rule.connection.getSent().iterator().next()) + + "}") + .matches( + bb -> { + switch (frameType) { + case REQUEST_RESPONSE: + return ByteBufUtil.equals( + RequestResponseFrameCodec.data(bb), + Unpooled.wrappedBuffer("abc1".getBytes())); + case REQUEST_STREAM: + return ByteBufUtil.equals( + RequestStreamFrameCodec.data(bb), Unpooled.wrappedBuffer("abc1".getBytes())); + case REQUEST_CHANNEL: + return ByteBufUtil.equals( + RequestChannelFrameCodec.data(bb), Unpooled.wrappedBuffer("abc1".getBytes())); + } + + return false; + }) + .matches(ReferenceCounted::release); + + if (frameType == REQUEST_CHANNEL) { + assertThat(rule.connection.getSent()) + .element(1) + .matches(bb -> frameType(bb) == COMPLETE) + .matches( + bb -> FrameHeaderCodec.streamId(bb) == 3, + "Expected to have stream ID {1} but got {" + + FrameHeaderCodec.streamId(new ArrayList<>(rule.connection.getSent()).get(1)) + + "}") + .matches(ReferenceCounted::release); + } + } + + private static Stream requestNInteractions() { + return Stream.of( + Arguments.of( + REQUEST_RESPONSE, + (BiFunction>) + (rule, payload) -> rule.socket.requestResponse(payload)), + Arguments.of( + REQUEST_STREAM, + (BiFunction>) + (rule, payload) -> rule.socket.requestStream(payload)), + Arguments.of( + REQUEST_CHANNEL, + (BiFunction>) + (rule, payload) -> rule.socket.requestChannel(Flux.just(payload)))); + } + + @ParameterizedTest + @MethodSource("streamRacingCases") + @Disabled("Connection should take care of ordering if such is necessary") + public void ensuresCorrectOrderOfStreamIdIssuingInCaseOfRacing( + BiFunction> interaction1, + BiFunction> interaction2, + FrameType interactionType1, + FrameType interactionType2) { + Assumptions.assumeThat(interactionType1).isNotEqualTo(METADATA_PUSH); + Assumptions.assumeThat(interactionType2).isNotEqualTo(METADATA_PUSH); + for (int i = 1; i < RaceTestConstants.REPEATS; i += 4) { + Payload payload = DefaultPayload.create("test", "test"); + Publisher publisher1 = interaction1.apply(rule, payload); + Publisher publisher2 = interaction2.apply(rule, payload); + RaceTestUtils.race( + () -> publisher1.subscribe(AssertSubscriber.create()), + () -> publisher2.subscribe(AssertSubscriber.create())); + + assertThat(rule.connection.getSent()) + .extracting(FrameHeaderCodec::streamId) + .containsExactly(i, i + 2); + rule.connection.getSent().forEach(bb -> bb.release()); + rule.connection.getSent().clear(); + } + } + + public static Stream streamRacingCases() { + return Stream.of( + Arguments.of( + (BiFunction>) + (r, p) -> r.socket.fireAndForget(p), + (BiFunction>) + (r, p) -> r.socket.requestResponse(p), + REQUEST_FNF, + REQUEST_RESPONSE), + Arguments.of( + (BiFunction>) + (r, p) -> r.socket.requestResponse(p), + (BiFunction>) + (r, p) -> r.socket.requestStream(p), + REQUEST_RESPONSE, + REQUEST_STREAM), + Arguments.of( + (BiFunction>) + (r, p) -> r.socket.requestStream(p), + (BiFunction>) + (r, p) -> { + AtomicBoolean subscribed = new AtomicBoolean(); + Flux just = Flux.just(p).doOnSubscribe((__) -> subscribed.set(true)); + return r.socket + .requestChannel(just) + .doFinally( + __ -> { + if (!subscribed.get()) { + p.release(); + } + }); + }, + REQUEST_STREAM, + REQUEST_CHANNEL), + Arguments.of( + (BiFunction>) + (r, p) -> { + AtomicBoolean subscribed = new AtomicBoolean(); + Flux just = Flux.just(p).doOnSubscribe((__) -> subscribed.set(true)); + return r.socket + .requestChannel(just) + .doFinally( + __ -> { + if (!subscribed.get()) { + p.release(); + } + }); + }, + (BiFunction>) + (r, p) -> r.socket.fireAndForget(p), + REQUEST_CHANNEL, + REQUEST_FNF), + Arguments.of( + (BiFunction>) + (r, p) -> r.socket.metadataPush(p), + (BiFunction>) + (r, p) -> r.socket.fireAndForget(p), + METADATA_PUSH, + REQUEST_FNF)); + } + + @ParameterizedTest + @MethodSource("streamRacingCases") + @SuppressWarnings({"rawtypes", "unchecked"}) + public void shouldTerminateAllStreamsIfThereRacingBetweenDisposeAndRequests( + BiFunction> interaction1, + BiFunction> interaction2, + FrameType interactionType1, + FrameType interactionType2) { + for (int i = 1; i < RaceTestConstants.REPEATS; i++) { + Payload payload1 = ByteBufPayload.create("test", "test"); + Payload payload2 = ByteBufPayload.create("test", "test"); + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); + AssertSubscriber assertSubscriber2 = AssertSubscriber.create(); + Publisher publisher1 = interaction1.apply(rule, payload1); + Publisher publisher2 = interaction2.apply(rule, payload2); + RaceTestUtils.race( + () -> rule.socket.dispose(), + () -> publisher1.subscribe(assertSubscriber1), + () -> publisher2.subscribe(assertSubscriber2)); + + assertSubscriber1.await().assertTerminated(); + if (interactionType1 != REQUEST_FNF && interactionType1 != METADATA_PUSH) { + assertSubscriber1.assertError(ClosedChannelException.class); + } else { + try { + assertSubscriber1.assertError(ClosedChannelException.class); + } catch (Throwable t) { + // fnf call may be completed + assertSubscriber1.assertComplete(); + } + } + assertSubscriber2.await().assertTerminated(); + if (interactionType2 != REQUEST_FNF && interactionType2 != METADATA_PUSH) { + assertSubscriber2.assertError(ClosedChannelException.class); + } else { + try { + assertSubscriber2.assertError(ClosedChannelException.class); + } catch (Throwable t) { + // fnf call may be completed + assertSubscriber2.assertComplete(); + } + } + + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + rule.connection.getSent().clear(); + + assertThat(payload1.refCnt()).isZero(); + assertThat(payload2.refCnt()).isZero(); + } + } + + @Test + // see https://github.com/rsocket/rsocket-java/issues/858 + public void testWorkaround858() { + ByteBuf buffer = rule.alloc().buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + + rule.socket.requestResponse(ByteBufPayload.create(buffer)).subscribe(); + + rule.connection.addToReceivedBuffer( + ErrorFrameCodec.encode(rule.alloc(), 1, new RuntimeException("test"))); + + assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == REQUEST_RESPONSE) + .matches(ByteBuf::release); + + assertThat(rule.socket.isDisposed()).isFalse(); + + rule.assertHasNoLeaks(); + } + + @DisplayName("reassembles data") + @ParameterizedTest + @MethodSource("requestNInteractions") + void reassembleData( + FrameType frameType, + BiFunction> requestFunction) { + final int mtu = ThreadLocalRandom.current().nextInt(64, 256); + final LeaksTrackingByteBufAllocator leaksTrackingByteBufAllocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + final Payload requestPayload = genericPayload(leaksTrackingByteBufAllocator); + final Payload randomPayload = randomPayload(leaksTrackingByteBufAllocator); + List fragments = prepareFragments(leaksTrackingByteBufAllocator, mtu, randomPayload); + + final Publisher responsePublisher = requestFunction.apply(rule, requestPayload); + StepVerifier.create(responsePublisher) + .then(() -> rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0]))) + .assertNext( + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }) + .thenCancel() + .verify(); + + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(frameType).hasNoLeaks(); + + if (frameType == REQUEST_CHANNEL) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + + if (!rule.connection.getSent().isEmpty()) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(CANCEL).hasNoLeaks(); + } + + leaksTrackingByteBufAllocator.assertHasNoLeaks(); + } + + @DisplayName("reassembles metadata") + @ParameterizedTest + @MethodSource("requestNInteractions") + void reassembleMetadata( + FrameType frameType, + BiFunction> requestFunction) { + final int mtu = ThreadLocalRandom.current().nextInt(64, 256); + final LeaksTrackingByteBufAllocator leaksTrackingByteBufAllocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + final Payload requestPayload = genericPayload(leaksTrackingByteBufAllocator); + final Payload metadataOnlyPayload = randomMetadataOnlyPayload(leaksTrackingByteBufAllocator); + List fragments = + prepareFragments(leaksTrackingByteBufAllocator, mtu, metadataOnlyPayload); + + StepVerifier.create(requestFunction.apply(rule, requestPayload)) + .then(() -> rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0]))) + .assertNext( + responsePayload -> { + PayloadAssert.assertThat(responsePayload).isEqualTo(metadataOnlyPayload).hasNoLeaks(); + metadataOnlyPayload.release(); + }) + .thenCancel() + .verify(); + + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(frameType).hasNoLeaks(); + + if (frameType == REQUEST_CHANNEL) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + + if (!rule.connection.getSent().isEmpty()) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(CANCEL).hasNoLeaks(); + } + + leaksTrackingByteBufAllocator.assertHasNoLeaks(); + } + + @ParameterizedTest(name = "throws error if reassembling payload size exceeds {0}") + @MethodSource("requestNInteractions") + public void errorTooBigPayload( + FrameType frameType, + BiFunction> requestFunction) { + final int mtu = ThreadLocalRandom.current().nextInt(64, 256); + final int maxInboundPayloadSize = ThreadLocalRandom.current().nextInt(mtu + 1, 4096); + final LeaksTrackingByteBufAllocator leaksTrackingByteBufAllocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + final Payload requestPayload = genericPayload(leaksTrackingByteBufAllocator); + final Payload responsePayload = + fixedSizePayload(leaksTrackingByteBufAllocator, maxInboundPayloadSize + 1); + List fragments = prepareFragments(leaksTrackingByteBufAllocator, mtu, responsePayload); + responsePayload.release(); + + rule.setMaxInboundPayloadSize(maxInboundPayloadSize); + + StepVerifier.create(requestFunction.apply(rule, requestPayload)) + .then(() -> rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0]))) + .expectErrorMessage(String.format(ILLEGAL_REASSEMBLED_PAYLOAD_SIZE, maxInboundPayloadSize)) + .verify(); + + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(frameType).hasNoLeaks(); + + if (frameType == REQUEST_CHANNEL) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(CANCEL).hasNoLeaks(); + + leaksTrackingByteBufAllocator.assertHasNoLeaks(); + } + + @ParameterizedTest(name = "throws error if fragment before the last is < min MTU {0}") + @MethodSource("requestNInteractions") + public void errorFragmentTooSmall( + FrameType frameType, + BiFunction> requestFunction) { + final int mtu = 32; + final LeaksTrackingByteBufAllocator leaksTrackingByteBufAllocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + final Payload requestPayload = genericPayload(leaksTrackingByteBufAllocator); + final Payload responsePayload = fixedSizePayload(leaksTrackingByteBufAllocator, 156); + List fragments = prepareFragments(leaksTrackingByteBufAllocator, mtu, responsePayload); + responsePayload.release(); + + StepVerifier.create(requestFunction.apply(rule, requestPayload)) + .then(() -> rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0]))) + .expectErrorMessage("Fragment is too small.") + .verify(); + + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(frameType).hasNoLeaks(); + + if (frameType == REQUEST_CHANNEL) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(CANCEL).hasNoLeaks(); + + leaksTrackingByteBufAllocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @ValueSource(strings = {"stream", "channel"}) + // see https://github.com/rsocket/rsocket-java/issues/959 + public void testWorkaround959(String type) { + for (int i = 1; i < 20000; i += 2) { + ByteBuf buffer = rule.alloc().buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(3); + if (type.equals("stream")) { + rule.socket.requestStream(ByteBufPayload.create(buffer)).subscribe(assertSubscriber); + } else if (type.equals("channel")) { + rule.socket + .requestChannel(Flux.just(ByteBufPayload.create(buffer))) + .subscribe(assertSubscriber); + } + + final ByteBuf payloadFrame = + PayloadFrameCodec.encode( + rule.alloc(), i, false, false, true, Unpooled.EMPTY_BUFFER, Unpooled.EMPTY_BUFFER); + + RaceTestUtils.race( + () -> { + rule.connection.addToReceivedBuffer(payloadFrame.copy()); + rule.connection.addToReceivedBuffer(payloadFrame.copy()); + rule.connection.addToReceivedBuffer(payloadFrame); + }, + () -> { + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + }); + + assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); + + assertThat(rule.socket.isDisposed()).isFalse(); + + assertSubscriber.values().forEach(ReferenceCountUtil::safeRelease); + assertSubscriber.assertNoError(); + + rule.connection.clearSendReceiveBuffers(); + rule.assertHasNoLeaks(); + } + } + + public static class ClientSocketRule extends AbstractSocketRule { + + protected Sinks.Empty thisClosedSink; + protected Sinks.Empty otherClosedSink; + + @Override + protected RSocketRequester newRSocket() { + this.thisClosedSink = Sinks.empty(); + this.otherClosedSink = Sinks.empty(); + return new RSocketRequester( + connection, + PayloadDecoder.ZERO_COPY, + StreamIdSupplier.clientSupplier(), + 0, + maxFrameLength, + maxInboundPayloadSize, + Integer.MAX_VALUE, + Integer.MAX_VALUE, + null, + (__) -> null, + null, + thisClosedSink, + otherClosedSink.asMono().and(thisClosedSink.asMono())); + } + + public int getStreamIdForRequestType(FrameType expectedFrameType) { + assertThat(connection.getSent().size()) + .describedAs("Unexpected frames sent.") + .isGreaterThanOrEqualTo(1); + List framesFound = new ArrayList<>(); + for (ByteBuf frame : connection.getSent()) { + FrameType frameType = frameType(frame); + if (frameType == expectedFrameType) { + return FrameHeaderCodec.streamId(frame); + } + framesFound.add(frameType); + } + throw new AssertionError( + "No frames sent with frame type: " + + expectedFrameType + + ", frames found: " + + framesFound); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java new file mode 100644 index 000000000..4f689e396 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java @@ -0,0 +1,1269 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.ReassemblyUtils.ILLEGAL_REASSEMBLED_PAYLOAD_SIZE; +import static io.rsocket.core.TestRequesterResponderSupport.fixedSizePayload; +import static io.rsocket.core.TestRequesterResponderSupport.genericPayload; +import static io.rsocket.core.TestRequesterResponderSupport.prepareFragments; +import static io.rsocket.core.TestRequesterResponderSupport.randomMetadataOnlyPayload; +import static io.rsocket.core.TestRequesterResponderSupport.randomPayload; +import static io.rsocket.frame.FrameHeaderCodec.frameType; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameType.COMPLETE; +import static io.rsocket.frame.FrameType.ERROR; +import static io.rsocket.frame.FrameType.NEXT; +import static io.rsocket.frame.FrameType.NEXT_COMPLETE; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_FNF; +import static io.rsocket.frame.FrameType.REQUEST_N; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.RSocket; +import io.rsocket.RaceTestConstants; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.plugins.RequestInterceptor; +import io.rsocket.plugins.TestRequestInterceptor; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.test.util.TestSubscriber; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; +import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; +import org.assertj.core.api.Assumptions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; + +public class RSocketResponderTest { + + ServerSocketRule rule; + + @BeforeEach + public void setUp() { + Hooks.onNextDropped(ReferenceCountUtil::safeRelease); + Hooks.onErrorDropped(t -> {}); + rule = new ServerSocketRule(); + rule.init(); + } + + @AfterEach + public void tearDown() { + Hooks.resetOnErrorDropped(); + Hooks.resetOnNextDropped(); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + @Disabled + public void testHandleKeepAlive() { + rule.connection.addToReceivedBuffer( + KeepAliveFrameCodec.encode(rule.alloc(), true, 0, Unpooled.EMPTY_BUFFER)); + ByteBuf sent = rule.connection.awaitFrame(); + assertThat(frameType(sent)) + .describedAs("Unexpected frame sent.") + .isEqualTo(FrameType.KEEPALIVE); + /*Keep alive ack must not have respond flag else, it will result in infinite ping-pong of keep alive frames.*/ + assertThat(KeepAliveFrameCodec.respondFlag(sent)) + .describedAs("Unexpected keep-alive frame respond flag.") + .isEqualTo(false); + } + + @Test + @Timeout(2_000) + public void testHandleResponseFrameNoError() { + final int streamId = 4; + rule.connection.clearSendReceiveBuffers(); + final TestPublisher testPublisher = TestPublisher.create(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return testPublisher.mono(); + } + }); + rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE); + testPublisher.complete(); + FrameAssert.assertThat(rule.connection.awaitFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + testPublisher.assertWasNotCancelled(); + } + + @Test + @Timeout(2_000) + public void testHandlerEmitsError() { + final int streamId = 4; + rule.prefetch = 1; + rule.sendRequest(streamId, FrameType.REQUEST_STREAM); + FrameAssert.assertThat(rule.connection.awaitFrame()) + .typeOf(FrameType.ERROR) + .hasData("Request-Stream not implemented.") + .hasNoLeaks(); + } + + @Test + @Timeout(20_000) + public void testCancel() { + ByteBufAllocator allocator = rule.alloc(); + final int streamId = 4; + final AtomicBoolean cancelled = new AtomicBoolean(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + payload.release(); + return Mono.never().doOnCancel(() -> cancelled.set(true)); + } + }); + rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE); + + assertThat(rule.connection.getSent()).describedAs("Unexpected frame sent.").isEmpty(); + + rule.connection.addToReceivedBuffer(CancelFrameCodec.encode(allocator, streamId)); + + assertThat(rule.connection.getSent()).describedAs("Unexpected frame sent.").isEmpty(); + assertThat(cancelled.get()).describedAs("Subscription not cancelled.").isTrue(); + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @ValueSource(ints = {128, 256, FRAME_LENGTH_MASK}) + @Timeout(2_000) + public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentation( + int maxFrameLength) { + rule.setMaxFrameLength(maxFrameLength); + final int streamId = 4; + final AtomicBoolean cancelled = new AtomicBoolean(); + byte[] metadata = new byte[maxFrameLength]; + byte[] data = new byte[maxFrameLength]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + final RSocket acceptingSocket = + new RSocket() { + @Override + public Mono requestResponse(Payload p) { + p.release(); + return Mono.just(payload).doOnCancel(() -> cancelled.set(true)); + } + + @Override + public Flux requestStream(Payload p) { + p.release(); + return Flux.just(payload).doOnCancel(() -> cancelled.set(true)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads) + .doOnNext(Payload::release) + .subscribe( + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) { + subscription.request(1); + } + }); + return Flux.just(payload).doOnCancel(() -> cancelled.set(true)); + } + }; + rule.setAcceptingSocket(acceptingSocket); + + final Runnable[] runnables = { + () -> rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE), + () -> rule.sendRequest(streamId, FrameType.REQUEST_STREAM), + () -> rule.sendRequest(streamId, FrameType.REQUEST_CHANNEL) + }; + + for (Runnable runnable : runnables) { + rule.connection.clearSendReceiveBuffers(); + runnable.run(); + assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == FrameType.ERROR) + .matches( + bb -> + ErrorFrameCodec.dataUtf8(bb) + .contains(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, maxFrameLength))) + .matches(ReferenceCounted::release); + + assertThat(cancelled.get()).describedAs("Subscription not cancelled.").isTrue(); + } + + rule.assertHasNoLeaks(); + } + + @Test + public void checkNoLeaksOnRacingCancelFromRequestChannelAndNextFromUpstream() { + ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + final Sinks.One sink = Sinks.one(); + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(assertSubscriber); + return sink.asMono().flux(); + } + }, + Integer.MAX_VALUE); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf metadata1 = allocator.buffer(); + metadata1.writeCharSequence("abc1", CharsetUtil.UTF_8); + ByteBuf data1 = allocator.buffer(); + data1.writeCharSequence("def1", CharsetUtil.UTF_8); + ByteBuf nextFrame1 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata1, data1); + + ByteBuf metadata2 = allocator.buffer(); + metadata2.writeCharSequence("abc2", CharsetUtil.UTF_8); + ByteBuf data2 = allocator.buffer(); + data2.writeCharSequence("def2", CharsetUtil.UTF_8); + ByteBuf nextFrame2 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata2, data2); + + ByteBuf metadata3 = allocator.buffer(); + metadata3.writeCharSequence("abc3", CharsetUtil.UTF_8); + ByteBuf data3 = allocator.buffer(); + data3.writeCharSequence("def3", CharsetUtil.UTF_8); + ByteBuf nextFrame3 = + PayloadFrameCodec.encode(allocator, 1, false, true, true, metadata3, data3); + + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3), + () -> { + assertSubscriber.cancel(); + sink.tryEmitEmpty(); + }); + + assertThat(assertSubscriber.values()).allMatch(ReferenceCounted::release); + + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnComplete(1).expectNothing(); + } + } + + @Test + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestChannelTest() { + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + ((Flux) payloads) + .doOnNext(ReferenceCountUtil::safeRelease) + .subscribe(assertSubscriber); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(cancelFrame), + () -> { + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + sink.complete(); + }); + + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnCancel(1).expectNothing(); + } + } + + @Test + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestChannelTest1() { + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + ((Flux) payloads) + .doOnNext(ReferenceCountUtil::safeRelease) + .subscribe(assertSubscriber); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, 1, Integer.MAX_VALUE); + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(requestNFrame), + () -> rule.connection.addToReceivedBuffer(cancelFrame), + () -> { + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + sink.complete(); + }); + + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnCancel(1).expectNothing(); + rule.assertHasNoLeaks(); + } + } + + @Test + public void + checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromUpstreamOnErrorFromRequestChannelTest1() { + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + FluxSink[] sinks = new FluxSink[1]; + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(assertSubscriber); + + return Flux.create( + sink -> { + sinks[0] = sink; + }, + FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf metadata1 = allocator.buffer(); + metadata1.writeCharSequence("abc1", CharsetUtil.UTF_8); + ByteBuf data1 = allocator.buffer(); + data1.writeCharSequence("def1", CharsetUtil.UTF_8); + ByteBuf nextFrame1 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata1, data1); + + ByteBuf metadata2 = allocator.buffer(); + metadata2.writeCharSequence("abc2", CharsetUtil.UTF_8); + ByteBuf data2 = allocator.buffer(); + data2.writeCharSequence("def2", CharsetUtil.UTF_8); + ByteBuf nextFrame2 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata2, data2); + + ByteBuf metadata3 = allocator.buffer(); + metadata3.writeCharSequence("abc3", CharsetUtil.UTF_8); + ByteBuf data3 = allocator.buffer(); + data3.writeCharSequence("def3", CharsetUtil.UTF_8); + ByteBuf nextFrame3 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata3, data3); + + ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, 1, Integer.MAX_VALUE); + + ByteBuf m1 = allocator.buffer(); + m1.writeCharSequence("m1", CharsetUtil.UTF_8); + ByteBuf d1 = allocator.buffer(); + d1.writeCharSequence("d1", CharsetUtil.UTF_8); + Payload np1 = ByteBufPayload.create(d1, m1); + + ByteBuf m2 = allocator.buffer(); + m2.writeCharSequence("m2", CharsetUtil.UTF_8); + ByteBuf d2 = allocator.buffer(); + d2.writeCharSequence("d2", CharsetUtil.UTF_8); + Payload np2 = ByteBufPayload.create(d2, m2); + + ByteBuf m3 = allocator.buffer(); + m3.writeCharSequence("m3", CharsetUtil.UTF_8); + ByteBuf d3 = allocator.buffer(); + d3.writeCharSequence("d3", CharsetUtil.UTF_8); + Payload np3 = ByteBufPayload.create(d3, m3); + + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(requestNFrame), + () -> rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3), + () -> { + sink.next(np1); + sink.next(np2); + sink.next(np3); + sink.error(new RuntimeException()); + }); + + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + assertSubscriber + .assertTerminated() + .assertError(CancellationException.class) + .assertErrorMessage("Outbound has terminated with an error"); + assertThat(assertSubscriber.values()) + .allMatch( + msg -> { + ReferenceCountUtil.safeRelease(msg); + return msg.refCnt() == 0; + }); + rule.assertHasNoLeaks(); + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnError(1).expectNothing(); + } + } + + @Test + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestStreamTest1() { + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + payload.release(); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + Integer.MAX_VALUE); + + rule.sendRequest(1, REQUEST_STREAM); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(cancelFrame), + () -> { + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + }); + + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + + testRequestInterceptor.expectOnStart(1, REQUEST_STREAM).expectOnCancel(1).expectNothing(); + } + } + + @Test + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestResponseTest1() { + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + Operators.MonoSubscriber[] sources = new Operators.MonoSubscriber[1]; + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + payload.release(); + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + sources[0] = new Operators.MonoSubscriber<>(actual); + actual.onSubscribe(sources[0]); + } + }; + } + }, + Integer.MAX_VALUE); + + rule.sendRequest(1, REQUEST_RESPONSE); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(cancelFrame), + () -> { + sources[0].complete(ByteBufPayload.create("d1", "m1")); + }); + + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, REQUEST_RESPONSE) + .assertNext( + e -> + assertThat(e.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_CANCEL)) + .expectNothing(); + } + } + + @Test + public void simpleDiscardRequestStreamTest() { + ByteBufAllocator allocator = rule.alloc(); + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + payload.release(); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_STREAM); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + FluxSink sink = sinks[0]; + + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + rule.connection.addToReceivedBuffer(cancelFrame); + + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + + @Test + public void simpleDiscardRequestChannelTest() { + ByteBufAllocator allocator = rule.alloc(); + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + return (Flux) payloads; + } + }, + 1); + + rule.sendRequest(1, REQUEST_STREAM); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + + ByteBuf metadata1 = allocator.buffer(); + metadata1.writeCharSequence("abc1", CharsetUtil.UTF_8); + ByteBuf data1 = allocator.buffer(); + data1.writeCharSequence("def1", CharsetUtil.UTF_8); + ByteBuf nextFrame1 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata1, data1); + + ByteBuf metadata2 = allocator.buffer(); + metadata2.writeCharSequence("abc2", CharsetUtil.UTF_8); + ByteBuf data2 = allocator.buffer(); + data2.writeCharSequence("def2", CharsetUtil.UTF_8); + ByteBuf nextFrame2 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata2, data2); + + ByteBuf metadata3 = allocator.buffer(); + metadata3.writeCharSequence("abc3", CharsetUtil.UTF_8); + ByteBuf data3 = allocator.buffer(); + data3.writeCharSequence("de3", CharsetUtil.UTF_8); + ByteBuf nextFrame3 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata3, data3); + rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3); + + rule.connection.addToReceivedBuffer(cancelFrame); + + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("encodeDecodePayloadCases") + public void verifiesThatFrameWithNoMetadataHasDecodedCorrectlyIntoPayload( + FrameType frameType, int framesCnt, int responsesCnt) { + ByteBufAllocator allocator = rule.alloc(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(framesCnt); + TestPublisher testPublisher = TestPublisher.create(); + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + Mono.just(payload).subscribe(assertSubscriber); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + Mono.just(payload).subscribe(assertSubscriber); + return testPublisher.mono(); + } + + @Override + public Flux requestStream(Payload payload) { + Mono.just(payload).subscribe(assertSubscriber); + return testPublisher.flux(); + } + + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(assertSubscriber); + return testPublisher.flux(); + } + }, + 1); + + rule.sendRequest(1, frameType, ByteBufPayload.create("d")); + + // if responses number is bigger than 1 we have to send one extra requestN + if (responsesCnt > 1) { + rule.connection.addToReceivedBuffer( + RequestNFrameCodec.encode(allocator, 1, responsesCnt - 1)); + } + + // respond with specific number of elements + for (int i = 0; i < responsesCnt; i++) { + testPublisher.next(ByteBufPayload.create("rd" + i)); + } + + // Listen to incoming frames. Valid for RequestChannel case only + if (framesCnt > 1) { + for (int i = 1; i < responsesCnt; i++) { + rule.connection.addToReceivedBuffer( + PayloadFrameCodec.encode( + allocator, + 1, + false, + false, + true, + null, + Unpooled.wrappedBuffer(("d" + (i + 1)).getBytes()))); + } + } + + if (responsesCnt > 0) { + assertThat(rule.connection.getSent().stream().filter(bb -> frameType(bb) != REQUEST_N)) + .describedAs( + "Interaction Type :[%s]. Expected to observe %s frames sent", frameType, responsesCnt) + .hasSize(responsesCnt) + .allMatch(bb -> !FrameHeaderCodec.hasMetadata(bb)); + } + + if (framesCnt > 1) { + assertThat(rule.connection.getSent().stream().filter(bb -> frameType(bb) == REQUEST_N)) + .describedAs( + "Interaction Type :[%s]. Expected to observe single RequestN(%s) frame", + frameType, framesCnt - 1) + .hasSize(1) + .first() + .matches(bb -> RequestNFrameCodec.requestN(bb) == (framesCnt - 1)); + } + + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + assertThat(assertSubscriber.awaitAndAssertNextValueCount(framesCnt).values()) + .hasSize(framesCnt) + .allMatch(p -> !p.hasMetadata()) + .allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + + static Stream encodeDecodePayloadCases() { + return Stream.of( + Arguments.of(REQUEST_FNF, 1, 0), + Arguments.of(REQUEST_RESPONSE, 1, 1), + Arguments.of(REQUEST_STREAM, 1, 5), + Arguments.of(REQUEST_CHANNEL, 5, 5)); + } + + @ParameterizedTest + @MethodSource("refCntCases") + public void ensureSendsErrorOnIllegalRefCntPayload(FrameType frameType) { + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + Payload invalidPayload = ByteBufPayload.create("test", "test"); + invalidPayload.release(); + return Mono.just(invalidPayload); + } + + @Override + public Flux requestStream(Payload payload) { + Payload invalidPayload = ByteBufPayload.create("test", "test"); + invalidPayload.release(); + return Flux.just(invalidPayload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Payload invalidPayload = ByteBufPayload.create("test", "test"); + invalidPayload.release(); + return Flux.just(invalidPayload); + } + }); + + rule.sendRequest(1, frameType); + + assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches( + bb -> frameType(bb) == ERROR, + "Expect frame type to be {" + + ERROR + + "} but was {" + + frameType(rule.connection.getSent().iterator().next()) + + "}") + .matches(ByteBuf::release); + } + + private static Stream refCntCases() { + return Stream.of(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + } + + @Test + // see https://github.com/rsocket/rsocket-java/issues/858 + public void testWorkaround858() { + ByteBuf buffer = rule.alloc().buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + + TestPublisher testPublisher = TestPublisher.create(); + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads).doOnNext(ReferenceCounted::release).subscribe(); + + return testPublisher.flux(); + } + }); + + rule.connection.addToReceivedBuffer( + RequestChannelFrameCodec.encodeReleasingPayload( + rule.alloc(), 1, false, 1, ByteBufPayload.create(buffer))); + rule.connection.addToReceivedBuffer( + ErrorFrameCodec.encode(rule.alloc(), 1, new RuntimeException("test"))); + + assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == REQUEST_N) + .matches(ReferenceCounted::release); + + assertThat(rule.socket.isDisposed()).isFalse(); + testPublisher.assertWasCancelled(); + + rule.assertHasNoLeaks(); + } + + static Stream requestCases() { + return Stream.of(REQUEST_FNF, REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + } + + @DisplayName("reassembles payload") + @ParameterizedTest + @MethodSource("requestCases") + void reassemblePayload(FrameType frameType) { + AtomicReference receivedPayload = new AtomicReference<>(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + receivedPayload.set(payload); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + receivedPayload.set(payload); + return Mono.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestStream(Payload payload) { + receivedPayload.set(payload); + return Flux.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads).subscribe(receivedPayload::set, null, null, s -> s.request(1)); + return Flux.just(genericPayload(rule.allocator)); + } + }); + + final int mtu = ThreadLocalRandom.current().nextInt(64, 256); + final Payload randomPayload = randomPayload(rule.allocator); + List fragments = prepareFragments(rule.allocator, mtu, randomPayload, frameType); + + rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0])); + + PayloadAssert.assertThat(receivedPayload.get()).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + + if (frameType != REQUEST_FNF) { + FrameAssert.assertThat(rule.connection.getSent().poll()) + .typeOf(frameType == REQUEST_RESPONSE ? NEXT_COMPLETE : NEXT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasNoLeaks(); + if (frameType != REQUEST_RESPONSE) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + } + + rule.assertHasNoLeaks(); + } + + @DisplayName("reassembles metadata") + @ParameterizedTest + @MethodSource("requestCases") + void reassembleMetadataOnly(FrameType frameType) { + AtomicReference receivedPayload = new AtomicReference<>(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + receivedPayload.set(payload); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + receivedPayload.set(payload); + return Mono.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestStream(Payload payload) { + receivedPayload.set(payload); + return Flux.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads).subscribe(receivedPayload::set, null, null, s -> s.request(1)); + return Flux.just(genericPayload(rule.allocator)); + } + }); + + final int mtu = ThreadLocalRandom.current().nextInt(64, 256); + final Payload randomMetadataOnlyPayload = randomMetadataOnlyPayload(rule.allocator); + List fragments = + prepareFragments(rule.allocator, mtu, randomMetadataOnlyPayload, frameType); + + rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0])); + + PayloadAssert.assertThat(receivedPayload.get()) + .isEqualTo(randomMetadataOnlyPayload) + .hasNoLeaks(); + randomMetadataOnlyPayload.release(); + + if (frameType != REQUEST_FNF) { + FrameAssert.assertThat(rule.connection.getSent().poll()) + .typeOf(frameType == REQUEST_RESPONSE ? NEXT_COMPLETE : NEXT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasNoLeaks(); + if (frameType != REQUEST_RESPONSE) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + } + + rule.assertHasNoLeaks(); + } + + @ParameterizedTest(name = "throws error if reassembling payload size exceeds {0}") + @MethodSource("requestCases") + public void errorTooBigPayload(FrameType frameType) { + final int mtu = ThreadLocalRandom.current().nextInt(64, 256); + final int maxInboundPayloadSize = ThreadLocalRandom.current().nextInt(mtu + 1, 4096); + AtomicReference receivedPayload = new AtomicReference<>(); + rule.setMaxInboundPayloadSize(maxInboundPayloadSize); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + receivedPayload.set(payload); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + receivedPayload.set(payload); + return Mono.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestStream(Payload payload) { + receivedPayload.set(payload); + return Flux.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads).subscribe(receivedPayload::set, null, null, s -> s.request(1)); + return Flux.just(genericPayload(rule.allocator)); + } + }); + final Payload randomPayload = fixedSizePayload(rule.allocator, maxInboundPayloadSize + 1); + List fragments = prepareFragments(rule.allocator, mtu, randomPayload, frameType); + randomPayload.release(); + + rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0])); + + PayloadAssert.assertThat(receivedPayload.get()).isNull(); + + if (frameType != REQUEST_FNF) { + FrameAssert.assertThat(rule.connection.getSent().poll()) + .typeOf(ERROR) + .hasData( + "Failed to reassemble payload. Cause: " + + String.format(ILLEGAL_REASSEMBLED_PAYLOAD_SIZE, maxInboundPayloadSize)) + .hasNoLeaks(); + } + + rule.assertHasNoLeaks(); + } + + @ParameterizedTest(name = "throws error if fragment before the last is < min MTU {0}") + @MethodSource("requestCases") + public void errorFragmentTooSmall(FrameType frameType) { + final int mtu = 32; + AtomicReference receivedPayload = new AtomicReference<>(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + receivedPayload.set(payload); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + receivedPayload.set(payload); + return Mono.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestStream(Payload payload) { + receivedPayload.set(payload); + return Flux.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads).subscribe(receivedPayload::set, null, null, s -> s.request(1)); + return Flux.just(genericPayload(rule.allocator)); + } + }); + final Payload randomPayload = fixedSizePayload(rule.allocator, 156); + List fragments = prepareFragments(rule.allocator, mtu, randomPayload, frameType); + randomPayload.release(); + + rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0])); + + PayloadAssert.assertThat(receivedPayload.get()).isNull(); + + if (frameType != REQUEST_FNF) { + FrameAssert.assertThat(rule.connection.getSent().poll()) + .typeOf(ERROR) + .hasData("Failed to reassemble payload. Cause: Fragment is too small.") + .hasNoLeaks(); + } + + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("requestCases") + void receivingRequestOnStreamIdThaIsAlreadyInUseMUSTBeIgnored_ReassemblyCase( + FrameType requestType) { + AtomicReference receivedPayload = new AtomicReference<>(); + final Sinks.Empty delayer = Sinks.empty(); + rule.setAcceptingSocket( + new RSocket() { + + @Override + public Mono fireAndForget(Payload payload) { + receivedPayload.set(payload); + return delayer.asMono(); + } + + @Override + public Mono requestResponse(Payload payload) { + receivedPayload.set(payload); + return Mono.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); + } + + @Override + public Flux requestStream(Payload payload) { + receivedPayload.set(payload); + return Flux.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads).subscribe(receivedPayload::set, null, null, s -> s.request(1)); + return Flux.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); + } + }); + final Payload randomPayload1 = fixedSizePayload(rule.allocator, 128); + final List fragments1 = + prepareFragments(rule.allocator, 64, randomPayload1, requestType); + final Payload randomPayload2 = fixedSizePayload(rule.allocator, 128); + final List fragments2 = + prepareFragments(rule.allocator, 64, randomPayload2, requestType); + randomPayload2.release(); + rule.connection.addToReceivedBuffer(fragments1.remove(0)); + rule.connection.addToReceivedBuffer(fragments2.remove(0)); + + rule.connection.addToReceivedBuffer(fragments1.toArray(new ByteBuf[0])); + if (requestType != REQUEST_CHANNEL) { + rule.connection.addToReceivedBuffer(fragments2.toArray(new ByteBuf[0])); + delayer.tryEmitEmpty(); + } else { + delayer.tryEmitEmpty(); + rule.connection.addToReceivedBuffer(PayloadFrameCodec.encodeComplete(rule.allocator, 1)); + rule.connection.addToReceivedBuffer(fragments2.toArray(new ByteBuf[0])); + } + + PayloadAssert.assertThat(receivedPayload.get()).isEqualTo(randomPayload1).hasNoLeaks(); + randomPayload1.release(); + + if (requestType != REQUEST_FNF) { + FrameAssert.assertThat(rule.connection.getSent().poll()) + .typeOf(requestType == REQUEST_RESPONSE ? NEXT_COMPLETE : NEXT) + .hasNoLeaks(); + + if (requestType != REQUEST_RESPONSE) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + } + + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("requestCases") + void receivingRequestOnStreamIdThaIsAlreadyInUseMUSTBeIgnored(FrameType requestType) { + Assumptions.assumeThat(requestType).isNotEqualTo(REQUEST_FNF); + AtomicReference receivedPayload = new AtomicReference<>(); + final Sinks.One delayer = Sinks.one(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + receivedPayload.set(payload); + return Mono.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); + } + + @Override + public Flux requestStream(Payload payload) { + receivedPayload.set(payload); + return Flux.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads).subscribe(receivedPayload::set, null, null, s -> s.request(1)); + return Flux.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); + } + }); + final Payload randomPayload1 = fixedSizePayload(rule.allocator, 64); + final Payload randomPayload2 = fixedSizePayload(rule.allocator, 64); + rule.sendRequest(1, requestType, randomPayload1.retain()); + rule.sendRequest(1, requestType, randomPayload2); + + delayer.tryEmitEmpty(); + + PayloadAssert.assertThat(receivedPayload.get()).isEqualTo(randomPayload1).hasNoLeaks(); + randomPayload1.release(); + + FrameAssert.assertThat(rule.connection.getSent().poll()) + .typeOf(requestType == REQUEST_RESPONSE ? NEXT_COMPLETE : NEXT) + .hasNoLeaks(); + + if (requestType != REQUEST_RESPONSE) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + + rule.assertHasNoLeaks(); + } + + public static class ServerSocketRule extends AbstractSocketRule { + + private RSocket acceptingSocket; + private volatile int prefetch; + private RequestInterceptor requestInterceptor; + protected Sinks.Empty onCloseSink; + + @Override + protected void doInit() { + acceptingSocket = + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return Mono.just(payload); + } + }; + super.doInit(); + } + + public void setAcceptingSocket(RSocket acceptingSocket) { + this.acceptingSocket = acceptingSocket; + connection = new TestDuplexConnection(alloc()); + connectSub = TestSubscriber.create(); + this.prefetch = Integer.MAX_VALUE; + super.doInit(); + } + + public void setRequestInterceptor(RequestInterceptor requestInterceptor) { + this.requestInterceptor = requestInterceptor; + super.doInit(); + } + + public void setAcceptingSocket(RSocket acceptingSocket, int prefetch) { + this.acceptingSocket = acceptingSocket; + connection = new TestDuplexConnection(alloc()); + connectSub = TestSubscriber.create(); + this.prefetch = prefetch; + super.doInit(); + } + + @Override + protected RSocketResponder newRSocket() { + onCloseSink = Sinks.empty(); + return new RSocketResponder( + connection, + acceptingSocket, + PayloadDecoder.ZERO_COPY, + null, + 0, + maxFrameLength, + maxInboundPayloadSize, + __ -> requestInterceptor, + onCloseSink); + } + + private void sendRequest(int streamId, FrameType frameType) { + sendRequest(streamId, frameType, EmptyPayload.INSTANCE); + } + + private void sendRequest(int streamId, FrameType frameType, Payload payload) { + ByteBuf request; + + switch (frameType) { + case REQUEST_CHANNEL: + request = + RequestChannelFrameCodec.encodeReleasingPayload( + allocator, streamId, false, prefetch, payload); + break; + case REQUEST_STREAM: + request = + RequestStreamFrameCodec.encodeReleasingPayload( + allocator, streamId, prefetch, payload); + break; + case REQUEST_RESPONSE: + request = RequestResponseFrameCodec.encodeReleasingPayload(allocator, streamId, payload); + break; + case REQUEST_FNF: + request = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(allocator, streamId, payload); + break; + default: + throw new IllegalArgumentException("unsupported type: " + frameType); + } + + connection.addToReceivedBuffer(request); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java new file mode 100644 index 000000000..90e881257 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java @@ -0,0 +1,64 @@ +package io.rsocket.core; + +import io.rsocket.Closeable; +import io.rsocket.FrameAssert; +import io.rsocket.frame.FrameType; +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.test.util.TestServerTransport; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +public class RSocketServerFragmentationTest { + + @Test + public void serverErrorsWithEnabledFragmentationOnInsufficientMtu() { + Assertions.assertThatIllegalArgumentException() + .isThrownBy(() -> RSocketServer.create().fragment(2)) + .withMessage("The smallest allowed mtu size is 64 bytes, provided: 2"); + } + + @Test + public void serverSucceedsWithEnabledFragmentationOnSufficientMtu() { + TestServerTransport transport = new TestServerTransport(); + Closeable closeable = RSocketServer.create().fragment(100).bind(transport).block(); + closeable.dispose(); + transport.alloc().assertHasNoLeaks(); + } + + @Test + public void serverSucceedsWithDisabledFragmentation() { + TestServerTransport transport = new TestServerTransport(); + Closeable closeable = RSocketServer.create().bind(transport).block(); + closeable.dispose(); + transport.alloc().assertHasNoLeaks(); + } + + @Test + public void clientErrorsWithEnabledFragmentationOnInsufficientMtu() { + Assertions.assertThatIllegalArgumentException() + .isThrownBy(() -> RSocketConnector.create().fragment(2)) + .withMessage("The smallest allowed mtu size is 64 bytes, provided: 2"); + } + + @Test + public void clientSucceedsWithEnabledFragmentationOnSufficientMtu() { + TestClientTransport transport = new TestClientTransport(); + RSocketConnector.create().fragment(100).connect(transport).block(); + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .typeOf(FrameType.SETUP) + .hasNoLeaks(); + transport.testConnection().dispose(); + transport.alloc().assertHasNoLeaks(); + } + + @Test + public void clientSucceedsWithDisabledFragmentation() { + TestClientTransport transport = new TestClientTransport(); + RSocketConnector.connectWith(transport).block(); + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .typeOf(FrameType.SETUP) + .hasNoLeaks(); + transport.testConnection().dispose(); + transport.alloc().assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketServerTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketServerTest.java new file mode 100644 index 000000000..a335ac1f3 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketServerTest.java @@ -0,0 +1,201 @@ +/* + * Copyright 2015-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.Closeable; +import io.rsocket.FrameAssert; +import io.rsocket.RSocket; +import io.rsocket.exceptions.RejectedSetupException; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.test.util.TestServerTransport; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.Random; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; +import reactor.test.scheduler.VirtualTimeScheduler; + +public class RSocketServerTest { + + @Test + public void unexpectedFramesBeforeSetupFrame() { + TestServerTransport transport = new TestServerTransport(); + RSocketServer.create().bind(transport).block(); + + final TestDuplexConnection duplexConnection = transport.connect(); + + duplexConnection.addToReceivedBuffer( + KeepAliveFrameCodec.encode(duplexConnection.alloc(), false, 1, Unpooled.EMPTY_BUFFER)); + + StepVerifier.create(duplexConnection.onClose()) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(10)); + + FrameAssert.assertThat(duplexConnection.pollFrame()) + .isNotNull() + .typeOf(FrameType.ERROR) + .hasData("SETUP or RESUME frame must be received before any others") + .hasStreamIdZero() + .hasNoLeaks(); + duplexConnection.alloc().assertHasNoLeaks(); + } + + @Test + public void timeoutOnNoFirstFrame() { + final VirtualTimeScheduler scheduler = VirtualTimeScheduler.getOrSet(); + TestServerTransport transport = new TestServerTransport(); + try { + RSocketServer.create().maxTimeToFirstFrame(Duration.ofMinutes(2)).bind(transport).block(); + + final TestDuplexConnection duplexConnection = transport.connect(); + + scheduler.advanceTimeBy(Duration.ofMinutes(1)); + + Assertions.assertThat(duplexConnection.isDisposed()).isFalse(); + + scheduler.advanceTimeBy(Duration.ofMinutes(1)); + + StepVerifier.create(duplexConnection.onClose()) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(10)); + + FrameAssert.assertThat(duplexConnection.pollFrame()).isNull(); + } finally { + transport.alloc().assertHasNoLeaks(); + VirtualTimeScheduler.reset(); + } + } + + @Test + public void ensuresMaxFrameLengthCanNotBeLessThenMtu() { + RSocketServer.create() + .fragment(128) + .bind(new TestServerTransport().withMaxFrameLength(64)) + .as(StepVerifier::create) + .expectErrorMessage( + "Configured maximumTransmissionUnit[128] exceeds configured maxFrameLength[64]") + .verify(); + } + + @Test + public void ensuresMaxFrameLengthCanNotBeGreaterThenMaxPayloadSize() { + RSocketServer.create() + .maxInboundPayloadSize(128) + .bind(new TestServerTransport().withMaxFrameLength(256)) + .as(StepVerifier::create) + .expectErrorMessage("Configured maxFrameLength[256] exceeds maxPayloadSize[128]") + .verify(); + } + + @Test + public void ensuresMaxFrameLengthCanNotBeGreaterThenMaxPossibleFrameLength() { + RSocketServer.create() + .bind(new TestServerTransport().withMaxFrameLength(Integer.MAX_VALUE)) + .as(StepVerifier::create) + .expectErrorMessage( + "Configured maxFrameLength[" + + Integer.MAX_VALUE + + "] " + + "exceeds maxFrameLength limit " + + FRAME_LENGTH_MASK) + .verify(); + } + + @Test + public void unexpectedFramesBeforeSetup() { + Sinks.Empty connectedSink = Sinks.empty(); + + TestServerTransport transport = new TestServerTransport(); + Closeable server = + RSocketServer.create() + .acceptor( + (setup, sendingSocket) -> { + connectedSink.tryEmitEmpty(); + return Mono.just(new RSocket() {}); + }) + .bind(transport) + .block(); + + byte[] bytes = new byte[16_000_000]; + new Random().nextBytes(bytes); + + TestDuplexConnection connection = transport.connect(); + connection.addToReceivedBuffer( + RequestResponseFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Unpooled.EMPTY_BUFFER, + ByteBufAllocator.DEFAULT.buffer(bytes.length).writeBytes(bytes))); + + StepVerifier.create(connection.onClose()).expectComplete().verify(Duration.ofSeconds(30)); + assertThat(connectedSink.scan(Scannable.Attr.TERMINATED)) + .as("Connection should not succeed") + .isFalse(); + FrameAssert.assertThat(connection.pollFrame()) + .hasStreamIdZero() + .hasData("SETUP or RESUME frame must be received before any others") + .hasNoLeaks(); + server.dispose(); + transport.alloc().assertHasNoLeaks(); + } + + @Test + public void ensuresErrorFrameDeliveredPriorConnectionDisposal() { + TestServerTransport transport = new TestServerTransport(); + Closeable server = + RSocketServer.create() + .acceptor( + (setup, sendingSocket) -> Mono.error(new RejectedSetupException("ACCESS_DENIED"))) + .bind(transport) + .block(); + + TestDuplexConnection connection = transport.connect(); + connection.addToReceivedBuffer( + SetupFrameCodec.encode( + ByteBufAllocator.DEFAULT, + false, + 0, + 1, + Unpooled.EMPTY_BUFFER, + "metadata_type", + "data_type", + EmptyPayload.INSTANCE)); + + StepVerifier.create(connection.onClose()).expectComplete().verify(Duration.ofSeconds(30)); + FrameAssert.assertThat(connection.pollFrame()) + .hasStreamIdZero() + .hasData("ACCESS_DENIED") + .hasNoLeaks(); + server.dispose(); + transport.alloc().assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java new file mode 100644 index 000000000..e01e6ebdc --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java @@ -0,0 +1,605 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.exceptions.CustomRSocketException; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.LocalDuplexConnection; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicReference; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.reactivestreams.Publisher; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; + +public class RSocketTest { + + public final SocketRule rule = new SocketRule(); + + @BeforeEach + public void setup() { + rule.init(); + } + + @AfterEach + public void tearDownAndCheckOnLeaks() { + rule.alloc().assertHasNoLeaks(); + } + + @Test + public void rsocketDisposalShouldEndupWithNoErrorsOnClose() { + RSocket requestHandlingRSocket = + new RSocket() { + final Disposable disposable = Disposables.single(); + + @Override + public void dispose() { + disposable.dispose(); + } + + @Override + public boolean isDisposed() { + return disposable.isDisposed(); + } + }; + rule.setRequestAcceptor(requestHandlingRSocket); + rule.crs + .onClose() + .as(StepVerifier::create) + .expectSubscription() + .then(rule.crs::dispose) + .expectComplete() + .verify(Duration.ofMillis(100)); + + Assertions.assertThat(requestHandlingRSocket.isDisposed()).isTrue(); + } + + @Test + @Timeout(2_000) + public void testRequestReplyNoError() { + StepVerifier.create(rule.crs.requestResponse(DefaultPayload.create("hello"))) + .expectNextCount(1) + .expectComplete() + .verify(); + } + + @Test + @Timeout(2000) + public void testHandlerEmitsError() { + rule.setRequestAcceptor( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return Mono.error(new NullPointerException("Deliberate exception.")); + } + }); + rule.crs + .requestResponse(EmptyPayload.INSTANCE) + .as(StepVerifier::create) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(ApplicationErrorException.class) + .hasMessage("Deliberate exception.")) + .verify(Duration.ofMillis(100)); + } + + @Test + @Timeout(2000) + public void testHandlerEmitsCustomError() { + rule.setRequestAcceptor( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return Mono.error( + new CustomRSocketException(0x00000501, "Deliberate Custom exception.")); + } + }); + rule.crs + .requestResponse(EmptyPayload.INSTANCE) + .as(StepVerifier::create) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(CustomRSocketException.class) + .hasMessage("Deliberate Custom exception.") + .hasFieldOrPropertyWithValue("errorCode", 0x00000501)) + .verify(); + } + + @Test + @Timeout(2000) + public void testRequestPropagatesCorrectlyForRequestChannel() { + rule.setRequestAcceptor( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads) + // specifically limits request to 3 in order to prevent 256 request from limitRate + // hidden on the responder side + .take(3, true); + } + }); + + Flux.range(0, 3) + .map(i -> DefaultPayload.create("" + i)) + .as(rule.crs::requestChannel) + .as(publisher -> StepVerifier.create(publisher, 3)) + .expectSubscription() + .expectNextCount(3) + .expectComplete() + .verify(Duration.ofMillis(5000)); + } + + @Test + @Timeout(2000) + public void testStream() { + Flux responses = rule.crs.requestStream(DefaultPayload.create("Payload In")); + StepVerifier.create(responses).expectNextCount(10).expectComplete().verify(); + } + + @Test + @Timeout(200000) + public void testChannel() { + Flux requests = + Flux.range(0, 10).map(i -> DefaultPayload.create("streaming in -> " + i)); + Flux responses = rule.crs.requestChannel(requests); + StepVerifier.create(responses).expectNextCount(10).expectComplete().verify(); + } + + @Test + @Timeout(2000) + public void testErrorPropagatesCorrectly() { + AtomicReference error = new AtomicReference<>(); + rule.setRequestAcceptor( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads).doOnError(error::set); + } + }); + Flux requests = Flux.error(new RuntimeException("test")); + Flux responses = rule.crs.requestChannel(requests); + StepVerifier.create(responses).expectErrorMessage("test").verify(); + Assertions.assertThat(error.get()).isNull(); + } + + @Test + public void requestChannelCase_StreamIsTerminatedAfterBothSidesSentCompletion1() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + completeFromRequesterPublisher(requesterPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + completeFromResponderPublisher(responderPublisher, requesterSubscriber); + } + + @Test + public void requestChannelCase_StreamIsTerminatedAfterBothSidesSentCompletion2() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + completeFromResponderPublisher(responderPublisher, requesterSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + completeFromRequesterPublisher(requesterPublisher, responderSubscriber); + } + + @Test + public void + requestChannelCase_CancellationFromResponderShouldLeaveStreamInHalfClosedStateWithNextCompletionPossibleFromRequester() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + cancelFromResponderSubscriber(requesterPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + completeFromResponderPublisher(responderPublisher, requesterSubscriber); + } + + @Test + public void + requestChannelCase_CompletionFromRequesterShouldLeaveStreamInHalfClosedStateWithNextCancellationPossibleFromResponder() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + completeFromResponderPublisher(responderPublisher, requesterSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + cancelFromResponderSubscriber(requesterPublisher, responderSubscriber); + } + + @Test + public void + requestChannelCase_ensureThatRequesterSubscriberCancellationTerminatesStreamsOnBothSides() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + // ensures both sides are terminated + cancelFromRequesterSubscriber( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + } + + @Test + public void requestChannelCase_ErrorFromResponderShouldTerminatesStreamsOnBothSides() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + // ensures both sides are terminated + errorFromResponderPublisher( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + } + + @Test + public void requestChannelCase_ErrorFromRequesterShouldTerminatesStreamsOnBothSides() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + // ensures both sides are terminated + errorFromRequesterPublisher( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + } + + void initRequestChannelCase( + TestPublisher requesterPublisher, + AssertSubscriber requesterSubscriber, + TestPublisher responderPublisher, + AssertSubscriber responderSubscriber) { + rule.setRequestAcceptor( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(responderSubscriber); + return responderPublisher.flux(); + } + }); + + rule.crs.requestChannel(requesterPublisher).subscribe(requesterSubscriber); + + requesterPublisher.assertWasSubscribed(); + requesterSubscriber.assertSubscribed(); + + responderSubscriber.assertNotSubscribed(); + responderPublisher.assertWasNotSubscribed(); + + // firstRequest + requesterSubscriber.request(1); + requesterPublisher.assertMaxRequested(1); + requesterPublisher.next(DefaultPayload.create("initialData", "initialMetadata")); + + responderSubscriber.assertSubscribed(); + responderPublisher.assertWasSubscribed(); + } + + void nextFromRequesterPublisher( + TestPublisher requesterPublisher, AssertSubscriber responderSubscriber) { + // ensures that outerUpstream and innerSubscriber is not terminated so the requestChannel + requesterPublisher.assertSubscribers(1); + responderSubscriber.assertNotTerminated(); + + responderSubscriber.request(6); + requesterPublisher.next( + DefaultPayload.create("d1", "m1"), + DefaultPayload.create("d2"), + DefaultPayload.create("d3", "m3"), + DefaultPayload.create("d4"), + DefaultPayload.create("d5", "m5")); + + List innerPayloads = responderSubscriber.awaitAndAssertNextValueCount(6).values(); + Assertions.assertThat(innerPayloads.stream().map(Payload::getDataUtf8)) + .containsExactly("initialData", "d1", "d2", "d3", "d4", "d5"); + Assertions.assertThat(innerPayloads.stream().map(Payload::hasMetadata)) + .containsExactly(true, true, false, true, false, true); + Assertions.assertThat(innerPayloads.stream().map(Payload::getMetadataUtf8)) + .containsExactly("initialMetadata", "m1", "", "m3", "", "m5"); + } + + void completeFromRequesterPublisher( + TestPublisher requesterPublisher, AssertSubscriber responderSubscriber) { + // ensures that after sending complete upstream part is closed + requesterPublisher.complete(); + responderSubscriber.assertTerminated(); + requesterPublisher.assertNoSubscribers(); + } + + void cancelFromResponderSubscriber( + TestPublisher requesterPublisher, AssertSubscriber responderSubscriber) { + // ensures that after sending complete upstream part is closed + responderSubscriber.cancel(); + requesterPublisher.assertWasCancelled(); + requesterPublisher.assertNoSubscribers(); + } + + void nextFromResponderPublisher( + TestPublisher responderPublisher, AssertSubscriber requesterSubscriber) { + // ensures that downstream is not terminated so the requestChannel state is half-closed + responderPublisher.assertSubscribers(1); + requesterSubscriber.assertNotTerminated(); + + // ensures responderPublisher can send messages and outerSubscriber can receive them + requesterSubscriber.request(5); + responderPublisher.next( + DefaultPayload.create("rd1", "rm1"), + DefaultPayload.create("rd2"), + DefaultPayload.create("rd3", "rm3"), + DefaultPayload.create("rd4"), + DefaultPayload.create("rd5", "rm5")); + + List outerPayloads = requesterSubscriber.awaitAndAssertNextValueCount(5).values(); + Assertions.assertThat(outerPayloads.stream().map(Payload::getDataUtf8)) + .containsExactly("rd1", "rd2", "rd3", "rd4", "rd5"); + Assertions.assertThat(outerPayloads.stream().map(Payload::hasMetadata)) + .containsExactly(true, false, true, false, true); + Assertions.assertThat(outerPayloads.stream().map(Payload::getMetadataUtf8)) + .containsExactly("rm1", "", "rm3", "", "rm5"); + } + + void completeFromResponderPublisher( + TestPublisher responderPublisher, AssertSubscriber requesterSubscriber) { + // ensures that after sending complete inner upstream is closed + responderPublisher.complete(); + requesterSubscriber.assertTerminated(); + responderPublisher.assertNoSubscribers(); + } + + void cancelFromRequesterSubscriber( + TestPublisher requesterPublisher, + AssertSubscriber requesterSubscriber, + TestPublisher responderPublisher, + AssertSubscriber responderSubscriber) { + // ensures that after sending cancel the whole requestChannel is terminated + requesterSubscriber.cancel(); + // error should be propagated + responderSubscriber.assertTerminated(); + responderPublisher.assertWasCancelled(); + responderPublisher.assertNoSubscribers(); + // ensures that cancellation is propagated to the actual upstream + requesterPublisher.assertWasCancelled(); + requesterPublisher.assertNoSubscribers(); + } + + static final CustomRSocketException EXCEPTION = new CustomRSocketException(123456, "test"); + + void errorFromResponderPublisher( + TestPublisher requesterPublisher, + AssertSubscriber requesterSubscriber, + TestPublisher responderPublisher, + AssertSubscriber responderSubscriber) { + // ensures that after sending cancel the whole requestChannel is terminated + responderPublisher.error(EXCEPTION); + // error should be propagated + responderSubscriber.assertTerminated().assertError(CancellationException.class); + requesterSubscriber + .assertTerminated() + .assertError(CustomRSocketException.class) + .assertErrorMessage("test"); + // ensures that cancellation is propagated to the actual upstream + requesterPublisher.assertWasCancelled(); + requesterPublisher.assertNoSubscribers(); + } + + void errorFromRequesterPublisher( + TestPublisher requesterPublisher, + AssertSubscriber requesterSubscriber, + TestPublisher responderPublisher, + AssertSubscriber responderSubscriber) { + // ensures that after sending cancel the whole requestChannel is terminated + requesterPublisher.error(EXCEPTION); + // error should be propagated + responderSubscriber + .assertTerminated() + .assertError(CustomRSocketException.class) + .assertErrorMessage("test"); + requesterSubscriber + .assertTerminated() + .assertError(CustomRSocketException.class) + .assertErrorMessage("test"); + + // ensures that cancellation is propagated to the actual upstream + responderPublisher.assertWasCancelled(); + responderPublisher.assertNoSubscribers(); + } + + public static class SocketRule { + + Sinks.Many serverProcessor; + Sinks.Many clientProcessor; + private RSocketRequester crs; + + @SuppressWarnings("unused") + private RSocketResponder srs; + + private RSocket requestAcceptor; + + private LeaksTrackingByteBufAllocator allocator; + protected Sinks.Empty thisClosedSink; + protected Sinks.Empty otherClosedSink; + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } + + public void init() { + allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + serverProcessor = Sinks.many().multicast().directBestEffort(); + clientProcessor = Sinks.many().multicast().directBestEffort(); + + this.thisClosedSink = Sinks.empty(); + this.otherClosedSink = Sinks.empty(); + + LocalDuplexConnection serverConnection = + new LocalDuplexConnection("server", allocator, clientProcessor, serverProcessor); + LocalDuplexConnection clientConnection = + new LocalDuplexConnection("client", allocator, serverProcessor, clientProcessor); + + clientConnection.onClose().doFinally(__ -> serverConnection.dispose()).subscribe(); + serverConnection.onClose().doFinally(__ -> clientConnection.dispose()).subscribe(); + + requestAcceptor = + null != requestAcceptor + ? requestAcceptor + : new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return Flux.range(1, 10) + .map(i -> DefaultPayload.create("server got -> [" + payload + "]")); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads) + .map( + payload -> + DefaultPayload.create("server got -> [" + payload.toString() + "]")) + .subscribe(); + + return Flux.range(1, 10) + .map( + payload -> + DefaultPayload.create("server got -> [" + payload.toString() + "]")); + } + }; + + srs = + new RSocketResponder( + serverConnection, + requestAcceptor, + PayloadDecoder.DEFAULT, + null, + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + __ -> null, + otherClosedSink); + + crs = + new RSocketRequester( + clientConnection, + PayloadDecoder.DEFAULT, + StreamIdSupplier.clientSupplier(), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + 0, + 0, + null, + __ -> null, + null, + thisClosedSink, + otherClosedSink.asMono().and(thisClosedSink.asMono())); + } + + public void setRequestAcceptor(RSocket requestAcceptor) { + this.requestAcceptor = requestAcceptor; + init(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java b/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java new file mode 100644 index 000000000..3112a0943 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java @@ -0,0 +1,1108 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +import io.rsocket.RaceTestConstants; +import io.rsocket.internal.subscriber.AssertSubscriber; +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.TimeoutException; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; +import reactor.util.retry.Retry; + +public class ReconnectMonoTests { + + private Queue retries = new ConcurrentLinkedQueue<>(); + private Queue> received = new ConcurrentLinkedQueue<>(); + private Queue expired = new ConcurrentLinkedQueue<>(); + + @Test + public void shouldExpireValueOnRacingDisposeAndNext() { + Hooks.onErrorDropped(t -> {}); + Hooks.onNextDropped(System.out::println); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final int index = i; + final CoreSubscriber[] monoSubscribers = new CoreSubscriber[1]; + Subscription mockSubscription = Mockito.mock(Subscription.class); + final Mono stringMono = + new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + actual.onSubscribe(mockSubscription); + monoSubscribers[0] = actual; + } + }; + + final ReconnectMono reconnectMono = + stringMono + .doOnDiscard(Object.class, System.out::println) + .as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + RaceTestUtils.race(() -> monoSubscribers[0].onNext("value" + index), reconnectMono::dispose); + + monoSubscribers[0].onComplete(); + + subscriber.assertTerminated(); + Mockito.verify(mockSubscription).cancel(); + + if (!subscriber.errors().isEmpty()) { + subscriber + .assertError(CancellationException.class) + .assertErrorMessage("ReconnectMono has already been disposed"); + + assertThat(expired).containsOnly("value" + i); + } else { + subscriber.assertValues("value" + i); + } + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldNotifyAllTheSubscribersUnderRacingBetweenSubscribeAndComplete() { + Hooks.onErrorDropped(t -> {}); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + final AssertSubscriber raceSubscriber = new AssertSubscriber<>(); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + cold.next("value" + i); + + RaceTestUtils.race(cold::complete, () -> reconnectMono.subscribe(raceSubscriber)); + + subscriber.assertTerminated(); + subscriber.assertValues("value" + i); + raceSubscriber.assertValues("value" + i); + + assertThat(reconnectMono.resolvingInner.subscribers).isEqualTo(ResolvingOperator.READY); + + assertThat( + reconnectMono.resolvingInner.add( + new ResolvingOperator.MonoDeferredResolutionOperator<>( + reconnectMono.resolvingInner, subscriber))) + .isEqualTo(ResolvingOperator.READY_STATE); + + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); + + received.clear(); + } + } + + @Test + public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidate() { + Hooks.onErrorDropped(t -> {}); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final int index = i; + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + final AssertSubscriber raceSubscriber = new AssertSubscriber<>(); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + reconnectMono.resolvingInner.mainSubscriber.onNext("value_to_expire" + i); + reconnectMono.resolvingInner.mainSubscriber.onComplete(); + + RaceTestUtils.race( + reconnectMono::invalidate, + () -> { + reconnectMono.subscribe(raceSubscriber); + if (!raceSubscriber.isTerminated()) { + reconnectMono.resolvingInner.mainSubscriber.onNext("value_to_not_expire" + index); + reconnectMono.resolvingInner.mainSubscriber.onComplete(); + } + }); + + subscriber.assertTerminated(); + subscriber.assertValues("value_to_expire" + i); + + raceSubscriber.assertComplete(); + String v = raceSubscriber.values().get(0); + if (reconnectMono.resolvingInner.subscribers == ResolvingOperator.READY) { + assertThat(v).isEqualTo("value_to_not_expire" + index); + } else { + assertThat(v).isEqualTo("value_to_expire" + index); + } + + assertThat(expired).hasSize(1).containsOnly("value_to_expire" + i); + if (reconnectMono.resolvingInner.subscribers == ResolvingOperator.READY) { + assertThat(received) + .hasSize(2) + .containsExactly( + Tuples.of("value_to_expire" + i, reconnectMono), + Tuples.of("value_to_not_expire" + i, reconnectMono)); + } else { + assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value_to_expire" + i, reconnectMono)); + } + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidates() { + Hooks.onErrorDropped(t -> {}); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final int index = i; + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + final AssertSubscriber raceSubscriber = new AssertSubscriber<>(); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + reconnectMono.resolvingInner.mainSubscriber.onNext("value_to_expire" + i); + reconnectMono.resolvingInner.mainSubscriber.onComplete(); + + RaceTestUtils.race( + reconnectMono::invalidate, + reconnectMono::invalidate, + () -> { + reconnectMono.subscribe(raceSubscriber); + if (!raceSubscriber.isTerminated()) { + reconnectMono.resolvingInner.mainSubscriber.onNext( + "value_to_possibly_expire" + index); + reconnectMono.resolvingInner.mainSubscriber.onComplete(); + } + }); + + subscriber.assertTerminated(); + subscriber.assertValues("value_to_expire" + i); + + raceSubscriber.assertComplete(); + assertThat(raceSubscriber.values().get(0)) + .isIn("value_to_possibly_expire" + index, "value_to_expire" + index); + + if (expired.size() == 2) { + assertThat(expired) + .hasSize(2) + .containsExactly("value_to_expire" + i, "value_to_possibly_expire" + i); + } else { + assertThat(expired).hasSize(1).containsOnly("value_to_expire" + i); + } + if (received.size() == 2) { + assertThat(received) + .hasSize(2) + .containsExactly( + Tuples.of("value_to_expire" + i, reconnectMono), + Tuples.of("value_to_possibly_expire" + i, reconnectMono)); + } else { + assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value_to_expire" + i, reconnectMono)); + } + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldNotExpireNewlyResolvedValueIfBlockIsRacingWithInvalidate() { + Hooks.onErrorDropped(t -> {}); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final int index = i; + final Mono source = + Mono.fromSupplier( + new Supplier() { + boolean once = false; + + @Override + public String get() { + + if (!once) { + once = true; + return "value_to_expire" + index; + } + + return "value_to_not_expire" + index; + } + }); + + final ReconnectMono reconnectMono = + new ReconnectMono<>( + source.subscribeOn(Schedulers.boundedElastic()), onExpire(), onValue()); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + + subscriber.await().assertComplete(); + + assertThat(expired).isEmpty(); + + RaceTestUtils.race( + () -> + assertThat(reconnectMono.block()) + .matches( + (v) -> + v.equals("value_to_not_expire" + index) + || v.equals("value_to_expire" + index)), + reconnectMono::invalidate); + + subscriber.assertTerminated(); + + subscriber.assertValues("value_to_expire" + i); + + assertThat(expired).hasSize(1).containsOnly("value_to_expire" + i); + if (reconnectMono.resolvingInner.subscribers == ResolvingOperator.READY) { + await().atMost(Duration.ofSeconds(5)).until(() -> received.size() == 2); + assertThat(received) + .hasSize(2) + .containsExactly( + Tuples.of("value_to_expire" + i, reconnectMono), + Tuples.of("value_to_not_expire" + i, reconnectMono)); + } else { + assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value_to_expire" + i, reconnectMono)); + } + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribers() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestPublisher cold = TestPublisher.createCold(); + cold.next("value" + i); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final AssertSubscriber subscriber = new AssertSubscriber<>(); + final AssertSubscriber raceSubscriber = new AssertSubscriber<>(); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + assertThat(cold.subscribeCount()).isZero(); + + RaceTestUtils.race( + () -> reconnectMono.subscribe(subscriber), () -> reconnectMono.subscribe(raceSubscriber)); + + subscriber.assertTerminated(); + assertThat(raceSubscriber.isTerminated()).isTrue(); + + subscriber.assertValues("value" + i); + raceSubscriber.assertValues("value" + i); + + assertThat(reconnectMono.resolvingInner.subscribers).isEqualTo(ResolvingOperator.READY); + + assertThat(cold.subscribeCount()).isOne(); + + assertThat( + reconnectMono.resolvingInner.add( + new ResolvingOperator.MonoDeferredResolutionOperator<>( + reconnectMono.resolvingInner, subscriber))) + .isEqualTo(ResolvingOperator.READY_STATE); + + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); + + received.clear(); + } + } + + @Test + public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribeAndBlock() { + Duration timeout = Duration.ofMillis(100); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestPublisher cold = TestPublisher.createCold(); + cold.next("value" + i); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final AssertSubscriber subscriber = new AssertSubscriber<>(); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + assertThat(cold.subscribeCount()).isZero(); + + String[] values = new String[1]; + + RaceTestUtils.race( + () -> values[0] = reconnectMono.block(timeout), + () -> reconnectMono.subscribe(subscriber)); + + subscriber.assertTerminated(); + + subscriber.assertValues("value" + i); + assertThat(values).containsExactly("value" + i); + + assertThat(reconnectMono.resolvingInner.subscribers).isEqualTo(ResolvingOperator.READY); + + assertThat(cold.subscribeCount()).isOne(); + + assertThat( + reconnectMono.resolvingInner.add( + new ResolvingOperator.MonoDeferredResolutionOperator<>( + reconnectMono.resolvingInner, subscriber))) + .isEqualTo(ResolvingOperator.READY_STATE); + + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); + + received.clear(); + } + } + + @Test + public void shouldEstablishValueOnceInCaseOfRacingBetweenBlocks() { + Duration timeout = Duration.ofMillis(100); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestPublisher cold = TestPublisher.createCold(); + cold.next("value" + i); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + assertThat(cold.subscribeCount()).isZero(); + + String[] values1 = new String[1]; + String[] values2 = new String[1]; + + RaceTestUtils.race( + () -> values1[0] = reconnectMono.block(timeout), + () -> values2[0] = reconnectMono.block(timeout)); + + assertThat(values2).containsExactly("value" + i); + assertThat(values1).containsExactly("value" + i); + + assertThat(reconnectMono.resolvingInner.subscribers).isEqualTo(ResolvingOperator.READY); + + assertThat(cold.subscribeCount()).isOne(); + + assertThat( + reconnectMono.resolvingInner.add( + new ResolvingOperator.MonoDeferredResolutionOperator<>( + reconnectMono.resolvingInner, new AssertSubscriber<>()))) + .isEqualTo(ResolvingOperator.READY_STATE); + + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); + + received.clear(); + } + } + + @Test + public void shouldExpireValueOnRacingDisposeAndNoValueComplete() { + Hooks.onErrorDropped(t -> {}); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + RaceTestUtils.race(cold::complete, reconnectMono::dispose); + + subscriber.assertTerminated(); + + Throwable error = subscriber.errors().get(0); + + if (error instanceof CancellationException) { + assertThat(error) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + } else { + assertThat(error) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Source completed empty"); + } + + assertThat(expired).isEmpty(); + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldExpireValueOnRacingDisposeAndComplete() { + Hooks.onErrorDropped(t -> {}); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + cold.next("value" + i); + + RaceTestUtils.race(cold::complete, reconnectMono::dispose); + + subscriber.assertTerminated(); + + if (!subscriber.errors().isEmpty()) { + assertThat(subscriber.errors().get(0)) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + } else { + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); + subscriber.assertValues("value" + i); + } + + assertThat(expired).hasSize(1).containsOnly("value" + i); + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldExpireValueOnRacingDisposeAndError() { + Hooks.onErrorDropped(t -> {}); + RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + cold.next("value" + i); + + RaceTestUtils.race(() -> cold.error(runtimeException), reconnectMono::dispose); + + subscriber.assertTerminated(); + + if (!subscriber.errors().isEmpty()) { + Throwable error = subscriber.errors().get(0); + if (error instanceof CancellationException) { + assertThat(error) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + } else { + assertThat(error).isInstanceOf(RuntimeException.class).hasMessage("test"); + } + } else { + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); + subscriber.assertValues("value" + i); + } + + assertThat(expired).hasSize(1).containsOnly("value" + i); + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldExpireValueOnRacingDisposeAndErrorWithNoBackoff() { + Hooks.onErrorDropped(t -> {}); + RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono() + .retryWhen(Retry.max(1).filter(t -> t instanceof Exception)) + .as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + cold.next("value" + i); + + RaceTestUtils.race(() -> cold.error(runtimeException), reconnectMono::dispose); + + subscriber.assertTerminated(); + + if (!subscriber.errors().isEmpty()) { + Throwable error = subscriber.errors().get(0); + if (error instanceof CancellationException) { + assertThat(error) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + } else { + assertThat(error).matches(Exceptions::isRetryExhausted).hasCause(runtimeException); + } + + assertThat(expired).hasSize(1).containsOnly("value" + i); + } else { + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); + subscriber.assertValues("value" + i); + } + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldThrowOnBlocking() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + Assertions.assertThatThrownBy(() -> reconnectMono.block(Duration.ofMillis(100))) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Timeout on Mono blocking read"); + } + + @Test + public void shouldThrowOnBlockingIfHasAlreadyTerminated() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + publisher.error(new RuntimeException("test")); + + Assertions.assertThatThrownBy(() -> reconnectMono.block(Duration.ofMillis(100))) + .isInstanceOf(RuntimeException.class) + .hasMessage("test") + .hasSuppressedException(new Exception("Terminated with an error")); + } + + @Test + public void shouldBeScannable() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final Mono parent = publisher.mono(); + final ReconnectMono reconnectMono = + parent.as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final Scannable scannableOfReconnect = Scannable.from(reconnectMono); + + assertThat( + (List) + scannableOfReconnect.parents().map(s -> s.getClass()).collect(Collectors.toList())) + .hasSize(1) + .containsExactly(publisher.mono().getClass()); + assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.TERMINATED)).isEqualTo(false); + assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.ERROR)).isNull(); + + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + + final Scannable scannableOfMonoProcessor = Scannable.from(subscriber); + + assertThat( + (List) + scannableOfMonoProcessor + .parents() + .map(s -> s.getClass()) + .collect(Collectors.toList())) + .hasSize(4) + .containsExactly( + ResolvingOperator.MonoDeferredResolutionOperator.class, + ReconnectMono.ResolvingInner.class, + ReconnectMono.class, + publisher.mono().getClass()); + + reconnectMono.dispose(); + + assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.TERMINATED)).isEqualTo(true); + assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.ERROR)) + .isInstanceOf(CancellationException.class); + } + + @Test + public void shouldNotExpiredIfNotCompleted() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + AssertSubscriber subscriber = new AssertSubscriber<>(); + + reconnectMono.subscribe(subscriber); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); + + publisher.next("test"); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); + + reconnectMono.invalidate(); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); + publisher.assertSubscribers(1); + assertThat(publisher.subscribeCount()).isEqualTo(1); + + publisher.complete(); + + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1); + subscriber.assertTerminated(); + + publisher.assertSubscribers(0); + assertThat(publisher.subscribeCount()).isEqualTo(1); + } + + @Test + public void shouldNotEmitUntilCompletion() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + AssertSubscriber subscriber = new AssertSubscriber<>(); + + reconnectMono.subscribe(subscriber); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); + + publisher.next("test"); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); + + publisher.complete(); + + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1); + subscriber.assertTerminated(); + subscriber.assertValues("test"); + } + + @Test + public void shouldBePossibleToRemoveThemSelvesFromTheList_CancellationTest() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + AssertSubscriber subscriber = new AssertSubscriber<>(); + + reconnectMono.subscribe(subscriber); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); + + publisher.next("test"); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); + + subscriber.cancel(); + + assertThat(reconnectMono.resolvingInner.subscribers) + .isEqualTo(ResolvingOperator.EMPTY_SUBSCRIBED); + + publisher.complete(); + + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1); + assertThat(subscriber.values()).isEmpty(); + } + + @Test + public void shouldExpireValueOnDispose() { + final TestPublisher publisher = TestPublisher.create(); + // given + final int timeout = 10; + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + StepVerifier.create(reconnectMono) + .expectSubscription() + .then(() -> publisher.next("value")) + .expectNext("value") + .expectComplete() + .verify(Duration.ofSeconds(timeout)); + + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1); + + reconnectMono.dispose(); + + assertThat(expired).hasSize(1); + assertThat(received).hasSize(1); + assertThat(reconnectMono.isDisposed()).isTrue(); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) + .expectSubscription() + .expectError(CancellationException.class) + .verify(Duration.ofSeconds(timeout)); + } + + @Test + public void shouldNotifyAllTheSubscribers() { + final TestPublisher publisher = TestPublisher.create(); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final AssertSubscriber sub1 = new AssertSubscriber<>(); + final AssertSubscriber sub2 = new AssertSubscriber<>(); + final AssertSubscriber sub3 = new AssertSubscriber<>(); + final AssertSubscriber sub4 = new AssertSubscriber<>(); + + reconnectMono.subscribe(sub1); + reconnectMono.subscribe(sub2); + reconnectMono.subscribe(sub3); + reconnectMono.subscribe(sub4); + + assertThat(reconnectMono.resolvingInner.subscribers).hasSize(4); + + final ArrayList> subscribers = new ArrayList<>(200); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final AssertSubscriber subA = new AssertSubscriber<>(); + final AssertSubscriber subB = new AssertSubscriber<>(); + subscribers.add(subA); + subscribers.add(subB); + RaceTestUtils.race(() -> reconnectMono.subscribe(subA), () -> reconnectMono.subscribe(subB)); + } + + assertThat(reconnectMono.resolvingInner.subscribers).hasSize(RaceTestConstants.REPEATS * 2 + 4); + + sub1.cancel(); + + assertThat(reconnectMono.resolvingInner.subscribers).hasSize(RaceTestConstants.REPEATS * 2 + 3); + + publisher.next("value"); + + assertThat(sub1.scan(Scannable.Attr.CANCELLED)).isTrue(); + assertThat(sub2.values().get(0)).isEqualTo("value"); + assertThat(sub3.values().get(0)).isEqualTo("value"); + assertThat(sub4.values().get(0)).isEqualTo("value"); + + for (AssertSubscriber sub : subscribers) { + assertThat(sub.values().get(0)).isEqualTo("value"); + assertThat(sub.isTerminated()).isTrue(); + } + + assertThat(publisher.subscribeCount()).isEqualTo(1); + } + + @Test + public void shouldExpireValueExactlyOnceOnRacingBetweenInvalidates() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestPublisher cold = TestPublisher.createCold(); + cold.next("value"); + cold.complete(); + final int timeout = 10; + + final ReconnectMono reconnectMono = + cold.flux() + .takeLast(1) + .next() + .as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) + .expectSubscription() + .expectNext("value") + .expectComplete() + .verify(Duration.ofSeconds(timeout)); + + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + + RaceTestUtils.race(reconnectMono::invalidate, reconnectMono::invalidate); + + assertThat(expired).hasSize(1).containsOnly("value"); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + + cold.next("value2"); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) + .expectSubscription() + .expectNext("value2") + .expectComplete() + .verify(Duration.ofSeconds(timeout)); + + assertThat(expired).hasSize(1).containsOnly("value"); + assertThat(received) + .hasSize(2) + .containsOnly(Tuples.of("value", reconnectMono), Tuples.of("value2", reconnectMono)); + + assertThat(cold.subscribeCount()).isEqualTo(2); + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldExpireValueExactlyOnceOnRacingBetweenInvalidateAndDispose() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestPublisher cold = TestPublisher.createCold(); + cold.next("value"); + final int timeout = 10000; + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) + .expectSubscription() + .expectNext("value") + .expectComplete() + .verify(Duration.ofSeconds(timeout)); + + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + + RaceTestUtils.race(reconnectMono::invalidate, reconnectMono::dispose); + + assertThat(expired).hasSize(1).containsOnly("value"); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) + .expectSubscription() + .expectError(CancellationException.class) + .verify(Duration.ofSeconds(timeout)); + + assertThat(expired).hasSize(1).containsOnly("value"); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + + assertThat(cold.subscribeCount()).isEqualTo(1); + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldTimeoutRetryWithVirtualTime() { + // given + final int minBackoff = 1; + final int maxBackoff = 5; + final int timeout = 10; + + // then + StepVerifier.withVirtualTime( + () -> + Mono.error(new RuntimeException("Something went wrong")) + .retryWhen( + Retry.backoff(Long.MAX_VALUE, Duration.ofSeconds(minBackoff)) + .doAfterRetry(onRetry()) + .maxBackoff(Duration.ofSeconds(maxBackoff))) + .timeout(Duration.ofSeconds(timeout)) + .as(m -> new ReconnectMono<>(m, onExpire(), onValue())) + .subscribeOn(Schedulers.boundedElastic())) + .expectSubscription() + .thenAwait(Duration.ofSeconds(timeout)) + .expectError(TimeoutException.class) + .verify(Duration.ofSeconds(timeout)); + + assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + } + + @Test + public void ensuresThatMainSubscriberAllowsOnlyTerminationWithValue() { + final int timeout = 10; + final ReconnectMono reconnectMono = + new ReconnectMono<>(Mono.empty(), onExpire(), onValue()); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) + .expectSubscription() + .expectErrorSatisfies( + t -> + assertThat(t) + .hasMessage("Source completed empty") + .isInstanceOf(IllegalStateException.class)) + .verify(Duration.ofSeconds(timeout)); + } + + @Test + public void monoRetryNoBackoff() { + Mono mono = + Mono.error(new IOException()) + .retryWhen(Retry.max(2).doAfterRetry(onRetry())) + .as(m -> new ReconnectMono<>(m, onExpire(), onValue())); + + StepVerifier.create(mono).verifyErrorMatches(Exceptions::isRetryExhausted); + assertRetries(IOException.class, IOException.class); + + assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + } + + @Test + public void monoRetryFixedBackoff() { + Mono mono = + Mono.error(new IOException()) + .retryWhen(Retry.fixedDelay(1, Duration.ofMillis(500)).doAfterRetry(onRetry())) + .as(m -> new ReconnectMono<>(m, onExpire(), onValue())); + + StepVerifier.withVirtualTime(() -> mono) + .expectSubscription() + .expectNoEvent(Duration.ofMillis(300)) + .thenAwait(Duration.ofMillis(300)) + .verifyErrorMatches(Exceptions::isRetryExhausted); + + assertRetries(IOException.class); + + assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + } + + @Test + public void monoRetryExponentialBackoff() { + Mono mono = + Mono.error(new IOException()) + .retryWhen( + Retry.backoff(4, Duration.ofMillis(100)) + .maxBackoff(Duration.ofMillis(500)) + .jitter(0.0d) + .doAfterRetry(onRetry())) + .as(m -> new ReconnectMono<>(m, onExpire(), onValue())); + + StepVerifier.withVirtualTime(() -> mono) + .expectSubscription() + .thenAwait(Duration.ofMillis(100)) + .thenAwait(Duration.ofMillis(200)) + .thenAwait(Duration.ofMillis(400)) + .thenAwait(Duration.ofMillis(500)) + .verifyErrorMatches(Exceptions::isRetryExhausted); + + assertRetries(IOException.class, IOException.class, IOException.class, IOException.class); + + assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + } + + Consumer onRetry() { + return context -> retries.add(context); + } + + BiConsumer onValue() { + return (v, __) -> received.add(Tuples.of(v, __)); + } + + Consumer onExpire() { + return (v) -> expired.add(v); + } + + @SafeVarargs + private final void assertRetries(Class... exceptions) { + assertThat(retries.size()).isEqualTo(exceptions.length); + int index = 0; + for (Iterator it = retries.iterator(); it.hasNext(); ) { + Retry.RetrySignal retryContext = it.next(); + assertThat(retryContext.totalRetries()).isEqualTo(index); + assertThat(retryContext.failure().getClass()).isEqualTo(exceptions[index]); + index++; + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequestChannelRequesterFluxTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelRequesterFluxTest.java new file mode 100644 index 000000000..c1e0a6876 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelRequesterFluxTest.java @@ -0,0 +1,845 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameType.CANCEL; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.RaceTestConstants; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Signal; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; + +public class RequestChannelRequesterFluxTest { + + @BeforeAll + public static void setUp() { + StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); + } + + /* + * +-------------------------------+ + * | General Test Cases | + * +-------------------------------+ + */ + @ParameterizedTest + @ValueSource(strings = {"inbound", "outbound"}) + public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String completionCase) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(payload.refCnt()).isOne(); + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(10); + + Assertions.assertThat(payload.refCnt()).isOne(); + activeStreams.assertNoActiveStreams(); + + stateAssert.hasSubscribedFlag().hasRequestN(10).hasNoFirstFrameSentFlag(); + + publisher.assertMaxRequested(1).next(payload); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(10).hasFirstFrameSentFlag(); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .hasRequestN(10) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + assertSubscriber.request(1); + final ByteBuf requestNFrame = sender.awaitFrame(); + FrameAssert.assertThat(requestNFrame) + .isNotNull() + .hasRequestN(1) + .typeOf(FrameType.REQUEST_N) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check. Request N Frame should sent so request field should be 0 + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(11).hasFirstFrameSentFlag(); + + assertSubscriber.request(Long.MAX_VALUE); + final ByteBuf requestMaxNFrame = sender.awaitFrame(); + FrameAssert.assertThat(requestMaxNFrame) + .isNotNull() + .hasRequestN(Integer.MAX_VALUE) + .typeOf(FrameType.REQUEST_N) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + assertSubscriber.request(6); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + Payload nextPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelRequesterFlux.handlePayload(nextPayload); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload randomPayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, randomPayload); + + ByteBuf firstFragment = fragments.remove(0); + requestChannelRequesterFlux.handleNext(firstFragment, true, false); + firstFragment.release(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag(); + + for (int i = 0; i < fragments.size(); i++) { + boolean hasFollows = i != fragments.size() - 1; + ByteBuf followingFragment = fragments.get(i); + + requestChannelRequesterFlux.handleNext(followingFragment, hasFollows, false); + followingFragment.release(); + } + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag(); + + if (completionCase.equals("inbound")) { + requestChannelRequesterFlux.handleComplete(); + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isEqualTo(nextPayload).hasNoLeaks(), + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }) + .assertComplete(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag() + .hasInboundTerminated(); + + publisher.complete(); + FrameAssert.assertThat(sender.awaitFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + } else if (completionCase.equals("outbound")) { + publisher.complete(); + FrameAssert.assertThat(sender.awaitFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag() + .hasOutboundTerminated(); + + requestChannelRequesterFlux.handleComplete(); + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isEqualTo(nextPayload).hasNoLeaks(), + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }) + .assertComplete(); + } + + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void streamShouldErrorWithoutInitializingRemoteStreamIfSourceIsEmpty(boolean doRequest) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + if (doRequest) { + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + } + + publisher.complete(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + assertSubscriber + .assertTerminated() + .assertError(CancellationException.class) + .assertErrorMessage("Empty Source"); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void streamShouldPropagateErrorWithoutInitializingRemoteStreamIfTheFirstSignalIsError( + boolean doRequest) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + if (doRequest) { + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + } + + publisher.error(new RuntimeException("test")); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + assertSubscriber + .assertTerminated() + .assertError(RuntimeException.class) + .assertErrorMessage("test"); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @ValueSource(strings = {"inbound", "outbound"}) + public void streamShouldBeInHalfClosedStateOnTheInboundCancellation(String terminationMode) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + + Payload payload1 = TestRequesterResponderSupport.randomPayload(allocator); + Payload payload2 = TestRequesterResponderSupport.randomPayload(allocator); + Payload payload3 = TestRequesterResponderSupport.randomPayload(allocator); + + publisher.next(payload1.retain()); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasPayload(payload1) + .hasRequestN(Integer.MAX_VALUE) + .hasNoLeaks(); + payload1.release(); + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + publisher.assertMaxRequested(1); + + requestChannelRequesterFlux.handleRequestN(10); + publisher.assertMaxRequested(10); + + requestChannelRequesterFlux.handleRequestN(Long.MAX_VALUE); + publisher.assertMaxRequested(Long.MAX_VALUE); + + publisher.next(payload2.retain(), payload3.retain()); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.NEXT) + .hasPayload(payload2) + .hasNoLeaks(); + payload2.release(); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.NEXT) + .hasPayload(payload3) + .hasNoLeaks(); + payload3.release(); + + if (terminationMode.equals("outbound")) { + requestChannelRequesterFlux.handleCancel(); + + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasOutboundTerminated(); + + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + requestChannelRequesterFlux.handleComplete(); + } else if (terminationMode.equals("inbound")) { + requestChannelRequesterFlux.handleComplete(); + + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasInboundTerminated(); + + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + requestChannelRequesterFlux.handleCancel(); + } + + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + } + + @ParameterizedTest + @ValueSource(strings = {"inbound", "outbound"}) + public void errorShouldTerminateExecution(String terminationMode) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + + Payload payload1 = TestRequesterResponderSupport.randomPayload(allocator); + Payload payload2 = TestRequesterResponderSupport.randomPayload(allocator); + Payload payload3 = TestRequesterResponderSupport.randomPayload(allocator); + + publisher.next(payload1.retain()); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasPayload(payload1) + .hasRequestN(Integer.MAX_VALUE) + .hasNoLeaks(); + payload1.release(); + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + publisher.assertMaxRequested(1); + + requestChannelRequesterFlux.handleRequestN(10); + publisher.assertMaxRequested(10); + + requestChannelRequesterFlux.handleRequestN(Long.MAX_VALUE); + publisher.assertMaxRequested(Long.MAX_VALUE); + + publisher.next(payload2.retain(), payload3.retain()); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.NEXT) + .hasPayload(payload2) + .hasNoLeaks(); + payload2.release(); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.NEXT) + .hasPayload(payload3) + .hasNoLeaks(); + payload3.release(); + + if (terminationMode.equals("outbound")) { + publisher.error(new ApplicationErrorException("test")); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.ERROR) + .hasData("test") + .hasNoLeaks(); + } else if (terminationMode.equals("inbound")) { + requestChannelRequesterFlux.handleError(new ApplicationErrorException("test")); + publisher.assertWasCancelled(); + } + + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + } + + @Test + public void failOnOverflow() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(1); + stateAssert.hasSubscribedFlag().hasRequestN(1).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + + Payload payload1 = TestRequesterResponderSupport.randomPayload(allocator); + + publisher.next(payload1.retain()); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasPayload(payload1) + .hasRequestN(1) + .hasNoLeaks(); + payload1.release(); + + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + publisher.assertMaxRequested(1); + + Payload nextPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelRequesterFlux.handlePayload(nextPayload); + + Payload unrequestedPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelRequesterFlux.handlePayload(unrequestedPayload); + + final ByteBuf cancelFrame = sender.awaitFrame(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + assertSubscriber + .assertValuesWith(p -> PayloadAssert.assertThat(p).isSameAs(nextPayload).hasNoLeaks()) + .assertError() + .assertErrorMessage("The number of messages received exceeds the number requested"); + + publisher.assertWasCancelled(); + + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + } + + /* + * +--------------------------------+ + * | Racing Test Cases | + * +--------------------------------+ + */ + + static Stream cases() { + return Stream.of( + Arguments.arguments("complete", "sizeError"), + Arguments.arguments("complete", "refCntError"), + Arguments.arguments("complete", "onError"), + Arguments.arguments("error", "sizeError"), + Arguments.arguments("error", "refCntError"), + Arguments.arguments("error", "onError"), + Arguments.arguments("cancel", "sizeError"), + Arguments.arguments("cancel", "refCntError"), + Arguments.arguments("cancel", "onError")); + } + + @ParameterizedTest + @MethodSource("cases") + public void shouldHaveEventsDeliveredSeriallyWhenOutboundErrorRacingWithInboundSignals( + String inboundTerminationMode, String outboundTerminationMode) { + final RuntimeException outboundException = new RuntimeException("outboundException"); + final ApplicationErrorException inboundException = + new ApplicationErrorException("inboundException"); + + final ArrayList droppedErrors = new ArrayList<>(); + final Payload oversizePayload = + DefaultPayload.create(new byte[FRAME_LENGTH_MASK], new byte[FRAME_LENGTH_MASK]); + + Hooks.onErrorDropped(droppedErrors::add); + try { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.DEFER_CANCELLATION); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber> assertSubscriber = + requestChannelRequesterFlux.materialize().subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + + Payload requestPayload = TestRequesterResponderSupport.randomPayload(allocator); + publisher.next(requestPayload); + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasRequestN(Integer.MAX_VALUE) + .hasNoLeaks(); + + requestChannelRequesterFlux.handleRequestN(Long.MAX_VALUE); + + Payload responsePayload1 = TestRequesterResponderSupport.randomPayload(allocator); + Payload responsePayload2 = TestRequesterResponderSupport.randomPayload(allocator); + Payload responsePayload3 = TestRequesterResponderSupport.randomPayload(allocator); + + Payload releasedPayload = ByteBufPayload.create(Unpooled.EMPTY_BUFFER); + releasedPayload.release(); + + RaceTestUtils.race( + () -> { + if (outboundTerminationMode.equals("onError")) { + publisher.error(outboundException); + } else if (outboundTerminationMode.equals("refCntError")) { + publisher.next(releasedPayload); + } else { + publisher.next(oversizePayload); + } + }, + () -> { + requestChannelRequesterFlux.handlePayload(responsePayload1); + requestChannelRequesterFlux.handlePayload(responsePayload2); + requestChannelRequesterFlux.handlePayload(responsePayload3); + + if (inboundTerminationMode.equals("error")) { + requestChannelRequesterFlux.handleError(inboundException); + } else if (inboundTerminationMode.equals("complete")) { + requestChannelRequesterFlux.handleComplete(); + } else { + requestChannelRequesterFlux.handleCancel(); + } + }); + + ByteBuf errorFrameOrEmpty = sender.pollFrame(); + if (errorFrameOrEmpty != null) { + if (outboundTerminationMode.equals("onError")) { + FrameAssert.assertThat(errorFrameOrEmpty) + .typeOf(FrameType.ERROR) + .hasData("outboundException") + .hasNoLeaks(); + } else { + FrameAssert.assertThat(errorFrameOrEmpty).typeOf(FrameType.CANCEL).hasNoLeaks(); + } + } + + List> values = assertSubscriber.values(); + for (int j = 0; j < values.size(); j++) { + Signal signal = values.get(j); + + if (signal.isOnNext()) { + PayloadAssert.assertThat(signal.get()) + .describedAs("Expected that the next signal[%s] to have no leaks", j) + .hasNoLeaks(); + } else { + if (inboundTerminationMode.equals("error")) { + Assertions.assertThat(signal.isOnError()).isTrue(); + Throwable throwable = signal.getThrowable(); + if (throwable == inboundException) { + Assertions.assertThat(droppedErrors.get(0)) + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + Assertions.assertThat(throwable).isEqualTo(inboundException); + } else { + Assertions.assertThat(droppedErrors).containsOnly(inboundException); + Assertions.assertThat(throwable) + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + } + } else if (inboundTerminationMode.equals("complete")) { + if (signal.isOnComplete()) { + Assertions.assertThat(droppedErrors.get(0)) + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + } else { + Assertions.assertThat(droppedErrors).isEmpty(); + Assertions.assertThat(signal.getThrowable()) + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + } + } else { + Assertions.assertThat(signal.getThrowable()) + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + } + + Assertions.assertThat(j) + .describedAs( + "Expected that the error signal[%s] is the last signal, but the last was %s", + j, values.size() - 1) + .isEqualTo(values.size() - 1); + } + } + + allocator.assertHasNoLeaks(); + droppedErrors.clear(); + } + } finally { + Hooks.resetOnErrorDropped(); + } + } + + @ParameterizedTest + @ValueSource(strings = {"complete", "cancel"}) + public void shouldRemoveItselfFromActiveStreamsWhenInboundAndOutboundAreTerminated( + String outboundTerminationMode) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.DEFER_CANCELLATION); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber> assertSubscriber = + requestChannelRequesterFlux.materialize().subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + + Payload requestPayload = TestRequesterResponderSupport.randomPayload(allocator); + publisher.next(requestPayload); + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasRequestN(Integer.MAX_VALUE) + .hasNoLeaks(); + + requestChannelRequesterFlux.handleRequestN(Long.MAX_VALUE); + + RaceTestUtils.race( + () -> { + if (outboundTerminationMode.equals("cancel")) { + requestChannelRequesterFlux.handleCancel(); + } else { + publisher.complete(); + } + }, + requestChannelRequesterFlux::handleComplete); + + ByteBuf completeFrameOrNull = sender.pollFrame(); + if (completeFrameOrNull != null) { + FrameAssert.assertThat(completeFrameOrNull) + .hasStreamId(1) + .typeOf(FrameType.COMPLETE) + .hasNoLeaks(); + } + + assertSubscriber.assertTerminated().assertComplete(); + activeStreams.assertNoActiveStreams(); + allocator.assertHasNoLeaks(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java new file mode 100644 index 000000000..890458caf --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java @@ -0,0 +1,890 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameType.*; +import static reactor.test.publisher.TestPublisher.Violation.CLEANUP_ON_TERMINATE; +import static reactor.test.publisher.TestPublisher.Violation.DEFER_CANCELLATION; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.RaceTestConstants; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.Exceptions; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Signal; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; + +public class RequestChannelResponderSubscriberTest { + + @BeforeAll + public static void setUp() { + StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); + } + + /* + * +-------------------------------+ + * | General Test Cases | + * +-------------------------------+ + */ + @ParameterizedTest + @ValueSource(strings = {"inbound", "outbound", "inboundCancel"}) + public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String completionCase) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload firstPayload = TestRequesterResponderSupport.genericPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + publisher.assertMaxRequested(1); + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + + final AssertSubscriber assertSubscriber = + requestChannelResponderSubscriber.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(firstPayload.refCnt()).isOne(); + + // state machine check + stateAssert.hasSubscribedFlagOnly().hasRequestN(0); + + assertSubscriber.request(1); + + // state machine check + stateAssert.hasSubscribedFlag().hasFirstFrameSentFlag().hasRequestN(1); + + // should not send requestN since 1 is remaining + Assertions.assertThat(sender.isEmpty()).isTrue(); + + assertSubscriber.request(1); + + stateAssert.hasSubscribedFlag().hasRequestN(2).hasFirstFrameSentFlag(); + + // should not send requestN since 1 is remaining + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(REQUEST_N) + .hasStreamId(1) + .hasRequestN(1) + .hasNoLeaks(); + + publisher.next(TestRequesterResponderSupport.genericPayload(allocator)); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .typeOf(NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + assertSubscriber.request(Long.MAX_VALUE); + final ByteBuf requestMaxNFrame = sender.awaitFrame(); + FrameAssert.assertThat(requestMaxNFrame) + .isNotNull() + .hasRequestN(Integer.MAX_VALUE) + .typeOf(FrameType.REQUEST_N) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + Payload nextPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelResponderSubscriber.handlePayload(nextPayload); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload randomPayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, randomPayload); + + ByteBuf firstFragment = fragments.remove(0); + requestChannelResponderSubscriber.handleNext(firstFragment, true, false); + firstFragment.release(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag(); + + for (int i = 0; i < fragments.size(); i++) { + boolean hasFollows = i != fragments.size() - 1; + ByteBuf followingFragment = fragments.get(i); + + requestChannelResponderSubscriber.handleNext(followingFragment, hasFollows, false); + followingFragment.release(); + } + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag(); + + if (completionCase.equals("inbound")) { + requestChannelResponderSubscriber.handleComplete(); + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(firstPayload).hasNoLeaks(), + p -> PayloadAssert.assertThat(p).isSameAs(nextPayload).hasNoLeaks(), + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }) + .assertComplete(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag() + .hasInboundTerminated(); + + publisher.complete(); + FrameAssert.assertThat(sender.awaitFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + } else if (completionCase.equals("inboundCancel")) { + assertSubscriber.cancel(); + assertSubscriber.assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(firstPayload).hasNoLeaks(), + p -> PayloadAssert.assertThat(p).isSameAs(nextPayload).hasNoLeaks(), + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }); + + FrameAssert.assertThat(sender.awaitFrame()).typeOf(CANCEL).hasStreamId(1).hasNoLeaks(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag() + .hasInboundTerminated(); + + publisher.complete(); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.COMPLETE) + .hasStreamId(1) + .hasNoLeaks(); + } else if (completionCase.equals("outbound")) { + publisher.complete(); + FrameAssert.assertThat(sender.awaitFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag() + .hasOutboundTerminated(); + + requestChannelResponderSubscriber.handleComplete(); + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(p).hasNoLeaks(), + p -> PayloadAssert.assertThat(p).isEqualTo(nextPayload).hasNoLeaks(), + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }) + .assertComplete(); + } + + Assertions.assertThat(firstPayload.refCnt()).isZero(); + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + @Test + public void failOnOverflow() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload firstPayload = TestRequesterResponderSupport.genericPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + publisher.assertMaxRequested(1); + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + + final AssertSubscriber assertSubscriber = + requestChannelResponderSubscriber.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(firstPayload.refCnt()).isOne(); + + // state machine check + stateAssert.hasSubscribedFlagOnly().hasRequestN(0); + + assertSubscriber.request(1); + + // state machine check + stateAssert.hasSubscribedFlag().hasFirstFrameSentFlag().hasRequestN(1); + + // should not send requestN since 1 is remaining + Assertions.assertThat(sender.isEmpty()).isTrue(); + + assertSubscriber.request(1); + + stateAssert.hasSubscribedFlag().hasRequestN(2).hasFirstFrameSentFlag(); + + // should not send requestN since 1 is remaining + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(REQUEST_N) + .hasStreamId(1) + .hasRequestN(1) + .hasNoLeaks(); + + Payload nextPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelResponderSubscriber.handlePayload(nextPayload); + + Payload unrequestedPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelResponderSubscriber.handlePayload(unrequestedPayload); + + final ByteBuf cancelErrorFrame = sender.awaitFrame(); + FrameAssert.assertThat(cancelErrorFrame) + .isNotNull() + .typeOf(ERROR) + .hasData("The number of messages received exceeds the number requested") + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(firstPayload).hasNoLeaks(), + p -> PayloadAssert.assertThat(p).isSameAs(nextPayload).hasNoLeaks()) + .assertErrorMessage("The number of messages received exceeds the number requested"); + + Assertions.assertThat(firstPayload.refCnt()).isZero(); + Assertions.assertThat(nextPayload.refCnt()).isZero(); + Assertions.assertThat(unrequestedPayload.refCnt()).isZero(); + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + @Test + public void failOnOverflowBeforeFirstPayloadIsSent() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload firstPayload = TestRequesterResponderSupport.genericPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + publisher.assertMaxRequested(1); + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + + final AssertSubscriber assertSubscriber = + requestChannelResponderSubscriber.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(firstPayload.refCnt()).isOne(); + + // state machine check + stateAssert.hasSubscribedFlagOnly().hasRequestN(0); + + Payload unrequestedPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelResponderSubscriber.handlePayload(unrequestedPayload); + + final ByteBuf cancelErrorFrame = sender.awaitFrame(); + FrameAssert.assertThat(cancelErrorFrame) + .isNotNull() + .typeOf(ERROR) + .hasData("The number of messages received exceeds the number requested") + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + assertSubscriber.request(1); + + assertSubscriber + .assertValuesWith(p -> PayloadAssert.assertThat(p).isSameAs(firstPayload).hasNoLeaks()) + .assertErrorMessage("The number of messages received exceeds the number requested"); + + Assertions.assertThat(firstPayload.refCnt()).isZero(); + Assertions.assertThat(unrequestedPayload.refCnt()).isZero(); + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + /* + * +--------------------------------+ + * | Racing Test Cases | + * +--------------------------------+ + */ + + @Test + public void streamShouldWorkCorrectlyWhenRacingHandleCompleteWithSubscription() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; + final Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + + final AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + + RaceTestUtils.race( + () -> + requestChannelResponderSubscriber + .doOnNext(__ -> assertSubscriber.request(1)) + .subscribe(assertSubscriber), + () -> requestChannelResponderSubscriber.handleComplete()); + + stateAssert + .hasSubscribedFlag() + .hasInboundTerminated() + .hasFirstFrameSentFlag() + .hasRequestNBetween(1, 2); + + assertSubscriber + .assertValuesWith(p -> PayloadAssert.assertThat(p).isSameAs(firstPayload).hasNoLeaks()) + .assertTerminated() + .assertComplete(); + + publisher.complete(); + + if (sender.getSent().size() > 1) { + FrameAssert.assertThat(sender.awaitFrame()) + .hasStreamId(1) + .typeOf(REQUEST_N) + .hasRequestN(1) + .hasNoLeaks(); + } + FrameAssert.assertThat(sender.awaitFrame()).hasStreamId(1).typeOf(COMPLETE).hasNoLeaks(); + + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + } + + @Test + public void streamShouldWorkCorrectlyWhenRacingHandleErrorWithSubscription() { + ApplicationErrorException applicationErrorException = new ApplicationErrorException("test"); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + + final AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + + RaceTestUtils.race( + () -> requestChannelResponderSubscriber.subscribe(assertSubscriber), + () -> requestChannelResponderSubscriber.handleError(applicationErrorException)); + + stateAssert.isTerminated(); + + publisher.assertCancelled(1); + + if (!assertSubscriber.values().isEmpty()) { + assertSubscriber.assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(p).hasNoLeaks()); + } + + assertSubscriber + .assertTerminated() + .assertError(applicationErrorException.getClass()) + .assertErrorMessage("test"); + + allocator.assertHasNoLeaks(); + } + } + + @Test + public void streamShouldWorkCorrectlyWhenRacingOutboundErrorWithSubscription() { + RuntimeException exception = new RuntimeException("test"); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + + final AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + + RaceTestUtils.race( + () -> requestChannelResponderSubscriber.subscribe(assertSubscriber), + () -> publisher.error(exception)); + + stateAssert.isTerminated(); + + FrameAssert.assertThat(activeStreams.getDuplexConnection().awaitFrame()) + .typeOf(ERROR) + .hasData("test") + .hasStreamId(1) + .hasNoLeaks(); + + if (!assertSubscriber.values().isEmpty()) { + assertSubscriber.assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(p).hasNoLeaks()); + } + + assertSubscriber + .assertTerminated() + .assertError(CancellationException.class) + .assertErrorMessage("Outbound has terminated with an error"); + + allocator.assertHasNoLeaks(); + } + } + + @Test + public void streamShouldWorkCorrectlyWhenRacingHandleCancelWithSubscription() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + + final AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + + RaceTestUtils.race( + () -> requestChannelResponderSubscriber.subscribe(assertSubscriber), + () -> requestChannelResponderSubscriber.handleCancel()); + + stateAssert.isTerminated(); + + publisher.assertCancelled(1); + + if (!assertSubscriber.values().isEmpty()) { + assertSubscriber.assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(p).hasNoLeaks()); + } + + assertSubscriber + .assertTerminated() + .assertError(CancellationException.class) + .assertErrorMessage("Inbound has been canceled"); + + allocator.assertHasNoLeaks(); + } + } + + static Stream cases() { + return Stream.of( + Arguments.arguments("complete", "sizeError"), + Arguments.arguments("complete", "refCntError"), + Arguments.arguments("complete", "onError"), + Arguments.arguments("error", "sizeError"), + Arguments.arguments("error", "refCntError"), + Arguments.arguments("error", "onError"), + Arguments.arguments("cancel", "sizeError"), + Arguments.arguments("cancel", "refCntError"), + Arguments.arguments("cancel", "onError")); + } + + @ParameterizedTest + @MethodSource("cases") + public void shouldHaveEventsDeliveredSeriallyWhenOutboundErrorRacingWithInboundSignals( + String inboundTerminationMode, String outboundTerminationMode) { + final RuntimeException outboundException = new RuntimeException("outboundException"); + final ApplicationErrorException inboundException = + new ApplicationErrorException("inboundException"); + final ArrayList droppedErrors = new ArrayList<>(); + final Payload oversizePayload = + DefaultPayload.create(new byte[FRAME_LENGTH_MASK], new byte[FRAME_LENGTH_MASK]); + + Hooks.onErrorDropped(droppedErrors::add); + try { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = + TestPublisher.createNoncompliant(DEFER_CANCELLATION, CLEANUP_ON_TERMINATE); + + Payload requestPayload = TestRequesterResponderSupport.randomPayload(allocator); + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, requestPayload, activeStreams); + + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + final AssertSubscriber> assertSubscriber = + requestChannelResponderSubscriber + .materialize() + .subscribeWith(AssertSubscriber.create(0)); + + assertSubscriber.request(Integer.MAX_VALUE); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.REQUEST_N) + .hasRequestN(Integer.MAX_VALUE) + .hasNoLeaks(); + + requestChannelResponderSubscriber.handleRequestN(Long.MAX_VALUE); + + Payload responsePayload1 = TestRequesterResponderSupport.randomPayload(allocator); + Payload responsePayload2 = TestRequesterResponderSupport.randomPayload(allocator); + Payload responsePayload3 = TestRequesterResponderSupport.randomPayload(allocator); + + Payload releasedPayload = ByteBufPayload.create(Unpooled.EMPTY_BUFFER); + releasedPayload.release(); + + RaceTestUtils.race( + () -> { + if (outboundTerminationMode.equals("onError")) { + publisher.error(outboundException); + } else if (outboundTerminationMode.equals("refCntError")) { + publisher.next(releasedPayload); + } else { + publisher.next(oversizePayload); + } + }, + () -> { + requestChannelResponderSubscriber.handlePayload(responsePayload1); + requestChannelResponderSubscriber.handlePayload(responsePayload2); + requestChannelResponderSubscriber.handlePayload(responsePayload3); + + if (inboundTerminationMode.equals("error")) { + requestChannelResponderSubscriber.handleError(inboundException); + } else if (inboundTerminationMode.equals("complete")) { + requestChannelResponderSubscriber.handleComplete(); + } else { + requestChannelResponderSubscriber.handleCancel(); + } + }); + + ByteBuf errorFrameOrEmpty = sender.pollFrame(); + if (errorFrameOrEmpty != null) { + String message; + if (outboundTerminationMode.equals("onError")) { + message = outboundException.getMessage(); + } else if (outboundTerminationMode.equals("sizeError")) { + message = String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK); + } else { + message = "Failed to validate payload. Cause:refCnt: 0"; + } + FrameAssert.assertThat(errorFrameOrEmpty) + .typeOf(FrameType.ERROR) + .hasData(message) + .hasNoLeaks(); + } + + List> values = assertSubscriber.values(); + for (int j = 0; j < values.size(); j++) { + Signal signal = values.get(j); + + if (signal.isOnNext()) { + Payload payload = signal.get(); + if (j == 0) { + Assertions.assertThat(payload).isEqualTo(requestPayload); + } + + PayloadAssert.assertThat(payload) + .describedAs("Expected that the next signal[%s] to have no leaks", j) + .hasNoLeaks(); + } else { + if (inboundTerminationMode.equals("error")) { + Assertions.assertThat(signal.isOnError()).isTrue(); + Throwable throwable = signal.getThrowable(); + if (Exceptions.isMultiple(throwable)) { + Assertions.assertThat( + Arrays.stream(throwable.getSuppressed()).map(Throwable::getMessage)) + .containsExactlyInAnyOrder( + inboundException.getMessage(), + outboundTerminationMode.equals("onError") + ? "Outbound has terminated with an error" + : "Inbound has been canceled"); + } else { + if (throwable == inboundException) { + Assertions.assertThat(droppedErrors) + .hasSize(1) + .first() + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + } else { + Assertions.assertThat(droppedErrors).containsOnly(inboundException); + } + } + } else if (inboundTerminationMode.equals("complete")) { + Assertions.assertThat(droppedErrors).isEmpty(); + if (signal.isOnError()) { + Assertions.assertThat(signal.getThrowable()) + .isExactlyInstanceOf(CancellationException.class) + .matches( + t -> + t.getMessage().equals("Inbound has been canceled") + || t.getMessage().equals("Outbound has terminated with an error")); + } + } else { + Throwable throwable = signal.getThrowable(); + if (Exceptions.isMultiple(throwable)) { + Assertions.assertThat( + Arrays.stream(throwable.getSuppressed()).map(Throwable::getMessage)) + .containsExactlyInAnyOrder( + "Inbound has been canceled", + outboundTerminationMode.equals("onError") + ? "Outbound has terminated with an error" + : "Inbound has been canceled"); + } else { + Assertions.assertThat(throwable).isExactlyInstanceOf(CancellationException.class); + } + } + + Assertions.assertThat(j) + .describedAs( + "Expected that the %s signal[%s] is the last signal, but the last was %s", + signal, j, values.get(values.size() - 1)) + .isEqualTo(values.size() - 1); + } + } + + allocator.assertHasNoLeaks(); + droppedErrors.clear(); + } + } finally { + Hooks.resetOnErrorDropped(); + } + } + + @ParameterizedTest + @ValueSource(strings = {"onError", "sizeError", "refCntError", "cancel"}) + public void shouldHaveNoLeaksOnReassemblyAndCancelRacing(String terminationMode) { + final RuntimeException outboundException = new RuntimeException("outboundException"); + final Payload oversizePayload = + DefaultPayload.create(new byte[FRAME_LENGTH_MASK], new byte[FRAME_LENGTH_MASK]); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; + final TestPublisher publisher = + TestPublisher.createNoncompliant(DEFER_CANCELLATION, CLEANUP_ON_TERMINATE); + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(2); + + Payload firstPayload = TestRequesterResponderSupport.genericPayload(allocator); + final RequestChannelResponderSubscriber requestOperator = + new RequestChannelResponderSubscriber(1, Long.MAX_VALUE, firstPayload, activeStreams); + + publisher.subscribe(requestOperator); + requestOperator.subscribe(assertSubscriber); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload responsePayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, responsePayload); + + Payload releasedPayload1 = ByteBufPayload.create(new byte[0]); + Payload releasedPayload2 = ByteBufPayload.create(new byte[0]); + releasedPayload1.release(); + releasedPayload2.release(); + + RaceTestUtils.race( + () -> { + switch (terminationMode) { + case "onError": + publisher.error(outboundException); + break; + case "sizeError": + publisher.next(oversizePayload); + break; + case "refCntError": + publisher.next(releasedPayload1); + break; + case "cancel": + default: + assertSubscriber.cancel(); + } + }, + () -> { + int lastFragmentId = fragments.size() - 1; + for (int j = 0; j < fragments.size(); j++) { + ByteBuf frame = fragments.get(j); + requestOperator.handleNext(frame, lastFragmentId != j, false); + frame.release(); + } + }); + + List values = assertSubscriber.values(); + + PayloadAssert.assertThat(values.get(0)).isEqualTo(firstPayload).hasNoLeaks(); + + if (values.size() > 1) { + Payload payload = values.get(1); + PayloadAssert.assertThat(payload).isEqualTo(responsePayload).hasNoLeaks(); + } + + if (!sender.isEmpty()) { + if (terminationMode.equals("cancel")) { + assertSubscriber.assertNotTerminated(); + } else { + assertSubscriber.assertTerminated().assertError(); + } + + final ByteBuf requstFrame = sender.awaitFrame(); + FrameAssert.assertThat(requstFrame) + .isNotNull() + .typeOf(REQUEST_N) + .hasRequestN(1) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf terminalFrame = sender.awaitFrame(); + FrameAssert.assertThat(terminalFrame) + .isNotNull() + .typeOf(terminationMode.equals("cancel") ? CANCEL : ERROR) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + + PayloadAssert.assertThat(responsePayload).hasNoLeaks(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequestResponseRequesterMonoTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequestResponseRequesterMonoTest.java new file mode 100644 index 000000000..b39ac62d9 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RequestResponseRequesterMonoTest.java @@ -0,0 +1,698 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA; +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.TestRequesterResponderSupport.genericPayload; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.frame.FrameType; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.Arrays; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.Scannable; +import reactor.test.StepVerifier; + +public class RequestResponseRequesterMonoTest { + + @BeforeAll + public static void setUp() { + StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); + } + + /* + * +-------------------------------+ + * | General Test Cases | + * +-------------------------------+ + * + */ + + /** + * General StateMachine transition test. No Fragmentation enabled In this test we check that the + * given instance of RequestResponseMono: 1) subscribes 2) sends frame on the first request 3) + * terminates up on receiving the first signal (terminates on first next | error | next over + * reassembly | complete) + */ + @ParameterizedTest + @MethodSource("frameShouldBeSentOnSubscriptionResponses") + public void frameShouldBeSentOnSubscription( + BiFunction, StepVerifier> + transformer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = genericPayload(allocator); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + + final StateAssert stateAssert = + StateAssert.assertThat(RequestResponseRequesterMono.STATE, requestResponseRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + transformer + .apply( + requestResponseRequesterMono, + StepVerifier.create(requestResponseRequesterMono, 0) + .expectSubscription() + .then(stateAssert::hasSubscribedFlagOnly) + .then(() -> Assertions.assertThat(payload.refCnt()).isOne()) + .then(activeStreams::assertNoActiveStreams) + .thenRequest(1) + .then(() -> stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag()) + .then(() -> Assertions.assertThat(payload.refCnt()).isZero()) + .then(() -> activeStreams.assertHasStream(1, requestResponseRequesterMono))) + .verify(); + + PayloadAssert.assertThat(payload).isReleased(); + // should not add anything to map + activeStreams.assertNoActiveStreams(); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .typeOf(FrameType.REQUEST_RESPONSE) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + stateAssert.isTerminated(); + + if (!sender.isEmpty()) { + ByteBuf cancelFrame = sender.awaitFrame(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + static Stream, StepVerifier>> + frameShouldBeSentOnSubscriptionResponses() { + return Stream.of( + // next case + (rrm, sv) -> + sv.then(() -> rrm.handlePayload(EmptyPayload.INSTANCE)) + .expectNext(EmptyPayload.INSTANCE) + .expectComplete(), + // complete case + (rrm, sv) -> sv.then(rrm::handleComplete).expectComplete(), + // error case + (rrm, sv) -> + sv.then(() -> rrm.handleError(new ApplicationErrorException("test"))) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .hasMessage("test") + .isInstanceOf(ApplicationErrorException.class)), + // fragmentation case + (rrm, sv) -> { + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + StateAssert stateAssert = StateAssert.assertThat(rrm); + + return sv.then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFirstFragment( + rrm.allocator, + 64, + FrameType.REQUEST_RESPONSE, + 1, + payload.hasMetadata(), + payload.metadata(), + payload.data()); + rrm.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rrm.allocator, 64, 1, false, payload.metadata(), payload.data()); + rrm.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rrm.allocator, 64, 1, false, payload.metadata(), payload.data()); + rrm.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rrm.allocator, 64, 1, false, payload.metadata(), payload.data()); + rrm.handleNext(followingFrame, false, false); + followingFrame.release(); + }) + .then(stateAssert::isTerminated) + .assertNext( + p -> { + Assertions.assertThat(p.data()).isEqualTo(Unpooled.wrappedBuffer(data)); + + Assertions.assertThat(p.metadata()).isEqualTo(Unpooled.wrappedBuffer(metadata)); + p.release(); + }) + .then(payload::release) + .expectComplete(); + }, + (rrm, sv) -> { + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + StateAssert stateAssert = StateAssert.assertThat(rrm); + + ByteBuf[] fragments = + new ByteBuf[] { + FragmentationUtils.encodeFirstFragment( + rrm.allocator, + 64, + FrameType.REQUEST_RESPONSE, + 1, + payload.hasMetadata(), + payload.metadata(), + payload.data()), + FragmentationUtils.encodeFollowsFragment( + rrm.allocator, 64, 1, false, payload.metadata(), payload.data()), + FragmentationUtils.encodeFollowsFragment( + rrm.allocator, 64, 1, false, payload.metadata(), payload.data()) + }; + + final StepVerifier stepVerifier = + sv.then( + () -> { + rrm.handleNext(fragments[0], true, false); + fragments[0].release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + rrm.handleNext(fragments[1], true, false); + fragments[1].release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + rrm.handleNext(fragments[2], true, false); + fragments[2].release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then(payload::release) + .thenCancel() + .verifyLater(); + + stepVerifier.verify(); + + Assertions.assertThat(fragments).allMatch(bb -> bb.refCnt() == 0); + + return stepVerifier; + }); + } + + /** + * General StateMachine transition test. Fragmentation enabled In this test we check that the + * given instance of RequestResponseMono: 1) subscribes 2) sends fragments frames on the first + * request 3) terminates up on receiving the first signal (terminates on first next | error | next + * over reassembly | complete) + */ + @ParameterizedTest + @MethodSource("frameShouldBeSentOnSubscriptionResponses") + public void frameFragmentsShouldBeSentOnSubscription( + BiFunction, StepVerifier> + transformer) { + final int mtu = 64; + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(mtu); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + transformer + .apply( + requestResponseRequesterMono, + StepVerifier.create(requestResponseRequesterMono, 0) + .expectSubscription() + .then(stateAssert::hasSubscribedFlagOnly) + .then(() -> Assertions.assertThat(payload.refCnt()).isOne()) + .then(activeStreams::assertNoActiveStreams) + .thenRequest(1) + .then(() -> stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag()) + .then(() -> Assertions.assertThat(payload.refCnt()).isZero()) + .then(() -> activeStreams.assertHasStream(1, requestResponseRequesterMono))) + .verify(); + + // should not add anything to map + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(payload.refCnt()).isZero(); + + final ByteBuf frameFragment1 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment1) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET_WITH_METADATA) // 64 - 6 (frame headers) - 3 (encoded metadata + // length) - 3 frame length + .hasMetadata(Arrays.copyOf(metadata, 52)) + .hasData(Unpooled.EMPTY_BUFFER) + .hasFragmentsFollow() + .typeOf(FrameType.REQUEST_RESPONSE) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment2 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment2) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET_WITH_METADATA) // 64 - 6 (frame headers) - 3 (encoded metadata + // length) - 3 frame length + .hasMetadata(Arrays.copyOfRange(metadata, 52, 65)) + .hasData(Arrays.copyOf(data, 39)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment3 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment3) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET) // 64 - 6 (frame headers) - 3 frame length (no metadata - no length) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 39, 94)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment4 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment4) + .isNotNull() + .hasPayloadSize(35) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 94, 129)) + .hasNoFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + if (!sender.isEmpty()) { + FrameAssert.assertThat(sender.awaitFrame()) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + Assertions.assertThat(sender.isEmpty()).isTrue(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * General StateMachine transition test. Ensures that no fragment is sent if mono was cancelled + * before any requests + */ + @Test + public void shouldBeNoOpsOnCancel() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = ByteBufPayload.create("testData", "testMetadata"); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + activeStreams.assertNoActiveStreams(); + stateAssert.isUnsubscribed(); + + StepVerifier.create(requestResponseRequesterMono, 0) + .expectSubscription() + .then(() -> stateAssert.hasSubscribedFlagOnly()) + .then(() -> activeStreams.assertNoActiveStreams()) + .thenCancel() + .verify(); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * General state machine test Ensures that a Subscriber receives error signal and state migrate to + * the terminated in case the given payload is an invalid one. + */ + @ParameterizedTest + @MethodSource("shouldErrorOnIncorrectRefCntInGivenPayloadSource") + public void shouldErrorOnIncorrectRefCntInGivenPayload( + Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; + final Payload payload = ByteBufPayload.create(""); + payload.release(); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + monoConsumer.accept(requestResponseRequesterMono); + + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + static Stream> + shouldErrorOnIncorrectRefCntInGivenPayloadSource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .expectError(IllegalReferenceCountException.class) + .verify(), + requestResponseRequesterMono -> + Assertions.assertThatThrownBy(requestResponseRequesterMono::block) + .isInstanceOf(IllegalReferenceCountException.class)); + } + + /** + * General state machine test Ensures that a Subscriber receives error signal and state migrate to + * the terminated in case the given payload was release in the middle of interaction. + * Fragmentation is disabled + */ + @Test + public void shouldErrorOnIncorrectRefCntInGivenPayloadLatePhase() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; + final Payload payload = ByteBufPayload.create(""); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + StepVerifier.create(requestResponseRequesterMono, 0) + .expectSubscription() + .then(payload::release) + .thenRequest(1) + .expectError(IllegalReferenceCountException.class) + .verify(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * General state machine test Ensures that a Subscriber receives error signal and state migrate to + * the terminated in case the given payload was release in the middle of interaction. + * Fragmentation is enabled + */ + @Test + public void shouldErrorOnIncorrectRefCntInGivenPayloadLatePhaseWithFragmentation() { + final int mtu = 64; + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(mtu); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + StepVerifier.create(requestResponseRequesterMono, 0) + .expectSubscription() + .then(payload::release) + .thenRequest(1) + .expectError(IllegalReferenceCountException.class) + .verify(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * General state machine test Ensures that a Subscriber receives error signal and state migrates + * to the terminated in case the given payload is too big with disabled fragmentation + */ + @ParameterizedTest + @MethodSource("shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource") + public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( + Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; + + final byte[] metadata = new byte[FRAME_LENGTH_MASK]; + final byte[] data = new byte[FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + activeStreams.assertNoActiveStreams(); + stateAssert.isUnsubscribed(); + + monoConsumer.accept(requestResponseRequesterMono); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + static Stream> + shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)) + .verify(), + requestResponseRequesterMono -> + Assertions.assertThatThrownBy(requestResponseRequesterMono::block) + .hasMessage(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)); + } + + /** + * Ensures that error check happens exactly before frame sent. This cases ensures that in case no + * lease / other external errors appeared, the local subscriber received the same one. No frames + * should be sent + */ + @ParameterizedTest + @MethodSource("shouldErrorIfNoAvailabilitySource") + public void shouldErrorIfNoAvailability(Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(new RuntimeException("test")); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload payload = genericPayload(allocator); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + activeStreams.assertNoActiveStreams(); + stateAssert.isUnsubscribed(); + + monoConsumer.accept(requestResponseRequesterMono); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + static Stream> shouldErrorIfNoAvailabilitySource() { + return Stream.of( + (s) -> + StepVerifier.create(s, 0) + .expectSubscription() + .then(() -> StateAssert.assertThat(s).hasSubscribedFlagOnly()) + .thenRequest(1) + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)) + .verify(), + requestResponseRequesterMono -> + Assertions.assertThatThrownBy(requestResponseRequesterMono::block) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)); + } + + @Test + public void checkName() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload payload = genericPayload(allocator); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + + Assertions.assertThat(Scannable.from(requestResponseRequesterMono).name()) + .isEqualTo("source(RequestResponseMono)"); + requestResponseRequesterMono.cancel(); + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequestStreamRequesterFluxTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequestStreamRequesterFluxTest.java new file mode 100644 index 000000000..8702d1a80 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RequestStreamRequesterFluxTest.java @@ -0,0 +1,1227 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N; +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.Scannable; +import reactor.test.StepVerifier; + +public class RequestStreamRequesterFluxTest { + + @BeforeAll + public static void setUp() { + StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); + } + + /* + * +-------------------------------+ + * | General Test Cases | + * +-------------------------------+ + */ + + /** + * State Machine check. Ensure migration from + * + *
+   * UNSUBSCRIBED -> SUBSCRIBED
+   * SUBSCRIBED -> REQUESTED(1) -> REQUESTED(0)
+   * REQUESTED(0) -> REQUESTED(1) -> REQUESTED(0)
+   * REQUESTED(0) -> REQUESTED(MAX)
+   * REQUESTED(MAX) -> REQUESTED(MAX) && REASSEMBLY (extra flag enabled which indicates
+   * reassembly)
+   * REQUESTED(MAX) && REASSEMBLY -> TERMINATED
+   * 
+ */ + @Test + public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestStreamRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(payload.refCnt()).isOne(); + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(1); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertHasStream(1, requestStreamRequesterFlux); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .hasRequestN(1) + .typeOf(FrameType.REQUEST_STREAM) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + assertSubscriber.request(1); + final ByteBuf requestNFrame = sender.awaitFrame(); + FrameAssert.assertThat(requestNFrame) + .isNotNull() + .hasRequestN(1) + .typeOf(FrameType.REQUEST_N) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check. Request N Frame should sent so request field should be 0 + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(2).hasFirstFrameSentFlag(); + + assertSubscriber.request(Long.MAX_VALUE); + final ByteBuf requestMaxNFrame = sender.awaitFrame(); + FrameAssert.assertThat(requestMaxNFrame) + .isNotNull() + .hasRequestN(Integer.MAX_VALUE) + .typeOf(FrameType.REQUEST_N) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + assertSubscriber.request(6); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload randomPayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, randomPayload); + ByteBuf firstFragment = fragments.remove(0); + requestStreamRequesterFlux.handleNext(firstFragment, true, false); + firstFragment.release(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag(); + + for (int i = 0; i < fragments.size(); i++) { + boolean hasFollowing = i != fragments.size() - 1; + ByteBuf followingFragment = fragments.get(i); + + requestStreamRequesterFlux.handleNext(followingFragment, hasFollowing, false); + followingFragment.release(); + } + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag(); + + Payload finalRandomPayload = TestRequesterResponderSupport.randomPayload(allocator); + requestStreamRequesterFlux.handlePayload(finalRandomPayload); + requestStreamRequesterFlux.handleComplete(); + + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(), + p -> PayloadAssert.assertThat(p).isEqualTo(finalRandomPayload).hasNoLeaks()) + .assertComplete(); + + PayloadAssert.assertThat(randomPayload).hasNoLeaks(); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * State Machine check. Ensure migration from + * + *
+   * UNSUBSCRIBED -> SUBSCRIBED
+   * SUBSCRIBED -> REQUESTED(MAX)
+   * REQUESTED(MAX) -> TERMINATED
+   * 
+ */ + @Test + public void requestNFrameShouldBeSentExactlyOnceIfItIsMaxAllowed() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestStreamRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(payload.refCnt()).isOne(); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(Long.MAX_VALUE / 2 + 1); + + // state machine check + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertHasStream(1, requestStreamRequesterFlux); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .hasRequestN(Integer.MAX_VALUE) + .typeOf(FrameType.REQUEST_STREAM) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + assertSubscriber.request(1); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + requestStreamRequesterFlux.handlePayload(EmptyPayload.INSTANCE); + requestStreamRequesterFlux.handleComplete(); + + assertSubscriber.assertValues(EmptyPayload.INSTANCE).assertComplete(); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + /** + * State Machine check. Ensure migration from + * + *
+   * UNSUBSCRIBED -> SUBSCRIBED
+   * SUBSCRIBED -> REQUESTED(1) -> REQUESTED(0)
+   * 
+ * + * And then for the following cases: + * + *
+   * [0]: REQUESTED(0) -> REQUESTED(MAX) (with onNext and few extra request(1) which should not
+   * affect state anyhow and should not sent any extra frames)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [1]: REQUESTED(0) -> REQUESTED(MAX) (with onComplete rightaway)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [2]: REQUESTED(0) -> REQUESTED(MAX) (with onError rightaway)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [3]: REQUESTED(0) -> REASSEMBLY
+   *      REASSEMBLY -> REASSEMBLY && REQUESTED(MAX)
+   *      REASSEMBLY && REQUESTED(MAX) -> REQUESTED(MAX)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [4]: REQUESTED(0) -> REQUESTED(MAX)
+   *      REQUESTED(MAX) -> REASSEMBLY && REQUESTED(MAX)
+   *      REASSEMBLY && REQUESTED(MAX) -> TERMINATED (because of cancel() invocation)
+   * 
+ */ + @ParameterizedTest + @MethodSource("frameShouldBeSentOnFirstRequestResponses") + public void frameShouldBeSentOnFirstRequest( + BiFunction, StepVerifier> + transformer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + transformer + .apply( + requestStreamRequesterFlux, + StepVerifier.create(requestStreamRequesterFlux, 0) + .expectSubscription() + .then( + () -> + // state machine check + stateAssert.hasSubscribedFlagOnly()) + .then(() -> Assertions.assertThat(payload.refCnt()).isOne()) + .then(() -> activeStreams.assertNoActiveStreams()) + .thenRequest(1) + .then( + () -> + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag()) + .then(() -> Assertions.assertThat(payload.refCnt()).isZero()) + .then(() -> activeStreams.assertHasStream(1, requestStreamRequesterFlux))) + .verify(); + + Assertions.assertThat(payload.refCnt()).isZero(); + // should not add anything to map + activeStreams.assertNoActiveStreams(); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .hasRequestN(1) + .typeOf(FrameType.REQUEST_STREAM) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf requestNFrame = sender.awaitFrame(); + FrameAssert.assertThat(requestNFrame) + .isNotNull() + .typeOf(FrameType.REQUEST_N) + .hasRequestN(Integer.MAX_VALUE) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + if (!sender.isEmpty()) { + final ByteBuf cancelFrame = sender.awaitFrame(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + // state machine check + stateAssert.isTerminated(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + static Stream, StepVerifier>> + frameShouldBeSentOnFirstRequestResponses() { + return Stream.of( + (rsf, sv) -> + sv.then(() -> rsf.handlePayload(EmptyPayload.INSTANCE)) + .expectNext(EmptyPayload.INSTANCE) + .thenRequest(Long.MAX_VALUE) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag()) + .then(() -> rsf.handlePayload(EmptyPayload.INSTANCE)) + .thenRequest(1L) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag()) + .expectNext(EmptyPayload.INSTANCE) + .thenRequest(1L) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag()) + .then(rsf::handleComplete) + .thenRequest(1L) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf).isTerminated()) + .expectComplete(), + (rsf, sv) -> + sv.then(() -> rsf.handlePayload(EmptyPayload.INSTANCE)) + .expectNext(EmptyPayload.INSTANCE) + .thenRequest(Long.MAX_VALUE) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag()) + .then(rsf::handleComplete) + .thenRequest(1L) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf).isTerminated()) + .expectComplete(), + (rsf, sv) -> + sv.then(() -> rsf.handlePayload(EmptyPayload.INSTANCE)) + .expectNext(EmptyPayload.INSTANCE) + .thenRequest(Long.MAX_VALUE) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag()) + .then(() -> rsf.handleError(new ApplicationErrorException("test"))) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf).isTerminated()) + .thenRequest(1L) + .thenRequest(1L) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .hasMessage("test") + .isInstanceOf(ApplicationErrorException.class)), + (rsf, sv) -> { + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + final Payload payload2 = ByteBufPayload.create(data, metadata); + + return sv.then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFirstFragment( + rsf.allocator, + 64, + FrameType.NEXT, + 1, + payload.hasMetadata(), + payload.metadata(), + payload.data()); + rsf.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(1) + .hasSubscribedFlag() + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rsf.allocator, 64, 1, false, payload.metadata(), payload.data()); + rsf.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(1) + .hasSubscribedFlag() + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rsf.allocator, 64, 1, false, payload.metadata(), payload.data()); + rsf.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(1) + .hasSubscribedFlag() + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .thenRequest(Long.MAX_VALUE) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(Integer.MAX_VALUE) + .hasSubscribedFlag() + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rsf.allocator, 64, 1, false, payload.metadata(), payload.data()); + rsf.handleNext(followingFrame, false, false); + followingFrame.release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(Integer.MAX_VALUE) + .hasSubscribedFlag() + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag()) + .assertNext( + p -> { + Assertions.assertThat(p.data()).isEqualTo(Unpooled.wrappedBuffer(data)); + + Assertions.assertThat(p.metadata()).isEqualTo(Unpooled.wrappedBuffer(metadata)); + Assertions.assertThat(p.release()).isTrue(); + Assertions.assertThat(p.refCnt()).isZero(); + }) + .then(payload::release) + .then(() -> rsf.handlePayload(payload2)) + .thenRequest(1) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(Integer.MAX_VALUE) + .hasSubscribedFlag() + .hasFirstFrameSentFlag()) + .assertNext( + p -> { + Assertions.assertThat(p.data()).isEqualTo(Unpooled.wrappedBuffer(data)); + + Assertions.assertThat(p.metadata()).isEqualTo(Unpooled.wrappedBuffer(metadata)); + Assertions.assertThat(p.release()).isTrue(); + Assertions.assertThat(p.refCnt()).isZero(); + }) + .thenRequest(1) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(Integer.MAX_VALUE) + .hasSubscribedFlag() + .hasFirstFrameSentFlag()) + .then(rsf::handleComplete) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf).isTerminated()) + .expectComplete(); + }, + (rsf, sv) -> { + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload0 = ByteBufPayload.create(data, metadata); + final Payload payload = ByteBufPayload.create(data, metadata); + + ByteBuf[] fragments = + new ByteBuf[] { + FragmentationUtils.encodeFirstFragment( + rsf.allocator, + 64, + FrameType.NEXT, + 1, + payload.hasMetadata(), + payload.metadata(), + payload.data()), + FragmentationUtils.encodeFollowsFragment( + rsf.allocator, 64, 1, false, payload.metadata(), payload.data()), + FragmentationUtils.encodeFollowsFragment( + rsf.allocator, 64, 1, false, payload.metadata(), payload.data()) + }; + + final StepVerifier stepVerifier = + sv.then(() -> rsf.handlePayload(payload0)) + .assertNext(p -> PayloadAssert.assertThat(p).isEqualTo(payload0).hasNoLeaks()) + .thenRequest(Long.MAX_VALUE) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag()) + .then( + () -> { + rsf.handleNext(fragments[0], true, false); + fragments[0].release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + rsf.handleNext(fragments[1], true, false); + fragments[1].release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + rsf.handleNext(fragments[2], true, false); + fragments[2].release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .thenRequest(1) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .thenRequest(1) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then(payload::release) + .thenCancel() + .verifyLater(); + + stepVerifier.verify(); + // state machine check + StateAssert.assertThat(rsf).isTerminated(); + + Assertions.assertThat(fragments).allMatch(bb -> bb.refCnt() == 0); + + return stepVerifier; + }); + } + + /** + * State Machine check with fragmentation of the first payload. Ensure migration from + * + *
+   * UNSUBSCRIBED -> SUBSCRIBED
+   * SUBSCRIBED -> REQUESTED(1) -> REQUESTED(0)
+   * 
+ * + * And then for the following cases: + * + *
+   * [0]: REQUESTED(0) -> REQUESTED(MAX) (with onNext and few extra request(1) which should not
+   * affect state anyhow and should not sent any extra frames)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [1]: REQUESTED(0) -> REQUESTED(MAX) (with onComplete rightaway)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [2]: REQUESTED(0) -> REQUESTED(MAX) (with onError rightaway)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [3]: REQUESTED(0) -> REASSEMBLY
+   *      REASSEMBLY -> REASSEMBLY && REQUESTED(MAX)
+   *      REASSEMBLY && REQUESTED(MAX) -> REQUESTED(MAX)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [4]: REQUESTED(0) -> REQUESTED(MAX)
+   *      REQUESTED(MAX) -> REASSEMBLY && REQUESTED(MAX)
+   *      REASSEMBLY && REQUESTED(MAX) -> TERMINATED (because of cancel() invocation)
+   * 
+ */ + @ParameterizedTest + @MethodSource("frameShouldBeSentOnFirstRequestResponses") + public void frameFragmentsShouldBeSentOnFirstRequest( + BiFunction, StepVerifier> + transformer) { + final int mtu = 64; + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(mtu); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + transformer + .apply( + requestStreamRequesterFlux, + StepVerifier.create(requestStreamRequesterFlux, 0) + .expectSubscription() + .then(() -> Assertions.assertThat(payload.refCnt()).isOne()) + .then(() -> activeStreams.assertNoActiveStreams()) + .thenRequest(1) + .then(() -> Assertions.assertThat(payload.refCnt()).isZero()) + .then(() -> activeStreams.assertHasStream(1, requestStreamRequesterFlux))) + .verify(); + + // should not add anything to map + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(payload.refCnt()).isZero(); + + final ByteBuf frameFragment1 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment1) + .isNotNull() + .hasPayloadSize(64 - FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N) + // InitialRequestN size + .hasMetadata(Arrays.copyOf(metadata, 64 - FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N)) + .hasData(Unpooled.EMPTY_BUFFER) + .hasFragmentsFollow() + .typeOf(FrameType.REQUEST_STREAM) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment2 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment2) + .isNotNull() + .hasPayloadSize(64 - FRAME_OFFSET_WITH_METADATA) + .hasMetadata( + Arrays.copyOfRange(metadata, 64 - FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N, 65)) + .hasData(Arrays.copyOf(data, 35)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment3 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment3) + .isNotNull() + .hasPayloadSize(64 - FRAME_OFFSET) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 35, 35 + 55)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment4 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment4) + .isNotNull() + .hasPayloadSize(39) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 90, 129)) + .hasNoFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf requestNFrame = sender.awaitFrame(); + FrameAssert.assertThat(requestNFrame) + .isNotNull() + .typeOf(FrameType.REQUEST_N) + .hasRequestN(Integer.MAX_VALUE) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + if (!sender.isEmpty()) { + FrameAssert.assertThat(sender.awaitFrame()) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + Assertions.assertThat(sender.isEmpty()).isTrue(); + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * Case which ensures that if Payload has incorrect refCnt, the flux ends up with an appropriate + * error + */ + @ParameterizedTest + @MethodSource("shouldErrorOnIncorrectRefCntInGivenPayloadSource") + public void shouldErrorOnIncorrectRefCntInGivenPayload( + Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = ByteBufPayload.create(""); + payload.release(); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + monoConsumer.accept(requestStreamRequesterFlux); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + static Stream> + shouldErrorOnIncorrectRefCntInGivenPayloadSource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .expectError(IllegalReferenceCountException.class) + .verify(), + requestStreamRequesterFlux -> + Assertions.assertThatThrownBy(requestStreamRequesterFlux::blockLast) + .isInstanceOf(IllegalReferenceCountException.class)); + } + + /** + * Ensures that if Payload is release right after the subscription, the first request will exponse + * the error immediatelly and no frame will be sent to the remote party + */ + @Test + public void shouldErrorOnIncorrectRefCntInGivenPayloadLatePhase() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + + final Payload payload = ByteBufPayload.create(""); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + StepVerifier.create(requestStreamRequesterFlux, 0) + .expectSubscription() + .then( + () -> + // state machine check + stateAssert.hasSubscribedFlagOnly()) + .then(payload::release) + .thenRequest(1) + .then( + () -> + // state machine check + stateAssert.isTerminated()) + .expectError(IllegalReferenceCountException.class) + .verify(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * Ensures that if Payload is release right after the subscription, the first request will expose + * the error immediately and no frame will be sent to the remote party + */ + @Test + public void shouldErrorOnIncorrectRefCntInGivenPayloadLatePhaseWithFragmentation() { + final int mtu = 64; + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(mtu); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + StepVerifier.create(requestStreamRequesterFlux, 0) + .expectSubscription() + .then( + () -> + // state machine check + stateAssert.hasSubscribedFlagOnly()) + .then(payload::release) + .thenRequest(1) + .then( + () -> + // state machine check + stateAssert.isTerminated()) + .expectError(IllegalReferenceCountException.class) + .verify(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * Ensures that if the given payload is exits 16mb size with disabled fragmentation, than the + * appropriate validation happens and a corresponding error will be propagagted to the subscriber + */ + @ParameterizedTest + @MethodSource("shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource") + public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( + Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + + final byte[] metadata = new byte[FRAME_LENGTH_MASK]; + final byte[] data = new byte[FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + monoConsumer.accept(requestStreamRequesterFlux); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + static Stream> + shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource() { + return Stream.of( + (s) -> + StepVerifier.create(s, 0) + .expectSubscription() + .then( + () -> + // state machine check + StateAssert.assertThat(s).isTerminated()) + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)) + .verify(), + requestStreamRequesterFlux -> + Assertions.assertThatThrownBy(requestStreamRequesterFlux::blockLast) + .hasMessage(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)); + } + + /** + * Ensures that the interactions check and respect rsocket availability (such as leasing) and + * propagate an error to the final subscriber. No frame should be sent. Check should happens + * exactly on the first request. + */ + @ParameterizedTest + @MethodSource("shouldErrorIfNoAvailabilitySource") + public void shouldErrorIfNoAvailability(Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(new RuntimeException("test")); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + monoConsumer.accept(requestStreamRequesterFlux); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + allocator.assertHasNoLeaks(); + } + + static Stream> shouldErrorIfNoAvailabilitySource() { + return Stream.of( + (s) -> + StepVerifier.create(s, 0) + .expectSubscription() + .then( + () -> + // state machine check + StateAssert.assertThat(s).hasSubscribedFlagOnly()) + .thenRequest(1) + .then( + () -> + // state machine check + StateAssert.assertThat(s).isTerminated()) + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)) + .verify(), + requestStreamRequesterFlux -> + Assertions.assertThatThrownBy(requestStreamRequesterFlux::blockLast) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)); + } + + @Test + public void failOnOverflow() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestStreamRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(payload.refCnt()).isOne(); + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(1); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertHasStream(1, requestStreamRequesterFlux); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .hasRequestN(1) + .typeOf(FrameType.REQUEST_STREAM) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + Payload requestedPayload = TestRequesterResponderSupport.randomPayload(allocator); + requestStreamRequesterFlux.handlePayload(requestedPayload); + + Payload unrequestedPayload = TestRequesterResponderSupport.randomPayload(allocator); + requestStreamRequesterFlux.handlePayload(unrequestedPayload); + + final ByteBuf cancelFrame = sender.awaitFrame(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + assertSubscriber + .assertValuesWith(p -> PayloadAssert.assertThat(p).isEqualTo(requestedPayload).hasNoLeaks()) + .assertError() + .assertErrorMessage("The number of messages received exceeds the number requested"); + + PayloadAssert.assertThat(requestedPayload).isReleased(); + PayloadAssert.assertThat(unrequestedPayload).isReleased(); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + @Test + public void checkName() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + + Assertions.assertThat(Scannable.from(requestStreamRequesterFlux).name()) + .isEqualTo("source(RequestStreamFlux)"); + requestStreamRequesterFlux.cancel(); + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java new file mode 100644 index 000000000..06d050f6f --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java @@ -0,0 +1,790 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.frame.FrameType.COMPLETE; +import static io.rsocket.frame.FrameType.METADATA_PUSH; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_N; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.util.CharsetUtil; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.RaceTestConstants; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.plugins.TestRequestInterceptor; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.Supplier; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.assertj.core.api.Assumptions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; +import reactor.test.StepVerifier; +import reactor.test.util.RaceTestUtils; + +@SuppressWarnings("ALL") +public class RequesterOperatorsRacingTest { + + interface Scenario { + FrameType requestType(); + + Publisher requestOperator( + Supplier payloadsSupplier, RequesterResponderSupport requesterResponderSupport); + } + + static Stream scenarios() { + return Stream.of( + new Scenario() { + @Override + public FrameType requestType() { + return METADATA_PUSH; + } + + @Override + public Publisher requestOperator( + Supplier payloadsSupplier, + RequesterResponderSupport requesterResponderSupport) { + return new MetadataPushRequesterMono(payloadsSupplier.get(), requesterResponderSupport); + } + + @Override + public String toString() { + return MetadataPushRequesterMono.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_FNF; + } + + @Override + public Publisher requestOperator( + Supplier payloadsSupplier, + RequesterResponderSupport requesterResponderSupport) { + return new FireAndForgetRequesterMono( + payloadsSupplier.get(), requesterResponderSupport); + } + + @Override + public String toString() { + return FireAndForgetRequesterMono.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_RESPONSE; + } + + @Override + public Publisher requestOperator( + Supplier payloadsSupplier, + RequesterResponderSupport requesterResponderSupport) { + return new RequestResponseRequesterMono( + payloadsSupplier.get(), requesterResponderSupport); + } + + @Override + public String toString() { + return RequestResponseRequesterMono.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_STREAM; + } + + @Override + public Publisher requestOperator( + Supplier payloadsSupplier, + RequesterResponderSupport requesterResponderSupport) { + return new RequestStreamRequesterFlux( + payloadsSupplier.get(), requesterResponderSupport); + } + + @Override + public String toString() { + return RequestStreamRequesterFlux.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_CHANNEL; + } + + @Override + public Publisher requestOperator( + Supplier payloadsSupplier, + RequesterResponderSupport requesterResponderSupport) { + return new RequestChannelRequesterFlux( + Flux.generate(s -> s.next(payloadsSupplier.get())), requesterResponderSupport); + } + + @Override + public String toString() { + return RequestChannelRequesterFlux.class.getSimpleName(); + } + }); + } + + /* + * +--------------------------------+ + * | Racing Test Cases | + * +--------------------------------+ + */ + + /** Ensures single subscription happens in case of racing */ + @ParameterizedTest(name = "Should subscribe exactly once to {0}") + @MethodSource("scenarios") + public void shouldSubscribeExactlyOnce(Scenario scenario) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport requesterResponderSupport = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> + TestRequesterResponderSupport.genericPayload( + requesterResponderSupport.getAllocator()); + + final Publisher requestOperator = + scenario.requestOperator(payloadSupplier, requesterResponderSupport); + + StepVerifier stepVerifier = + StepVerifier.create(requesterResponderSupport.getDuplexConnection().getSentAsPublisher()) + .assertNext( + frame -> { + FrameAssert frameAssert = + FrameAssert.assertThat(frame) + .isNotNull() + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()); + if (scenario.requestType() == METADATA_PUSH) { + frameAssert + .hasStreamIdZero() + .hasPayloadSize( + TestRequesterResponderSupport.METADATA_CONTENT.getBytes( + CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT); + } else { + frameAssert + .hasClientSideStreamId() + .hasStreamId(1) + .hasPayloadSize( + TestRequesterResponderSupport.METADATA_CONTENT.getBytes( + CharsetUtil.UTF_8) + .length + + TestRequesterResponderSupport.DATA_CONTENT.getBytes( + CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT); + } + frameAssert.hasNoLeaks(); + + if (requestOperator instanceof FrameHandler) { + ((FrameHandler) requestOperator).handleComplete(); + if (scenario.requestType() == REQUEST_CHANNEL) { + ((FrameHandler) requestOperator).handleCancel(); + } + } + }) + .thenCancel() + .verifyLater(); + + Assertions.assertThatThrownBy( + () -> + RaceTestUtils.race( + () -> { + AssertSubscriber subscriber = new AssertSubscriber<>(); + requestOperator.subscribe(subscriber); + subscriber.await().assertTerminated().assertNoError(); + }, + () -> { + AssertSubscriber subscriber = new AssertSubscriber<>(); + requestOperator.subscribe(subscriber); + subscriber.await().assertTerminated().assertNoError(); + })) + .matches( + t -> { + Assertions.assertThat(t).hasMessageContaining("allows only a single Subscriber"); + return true; + }); + + stepVerifier.verify(Duration.ofSeconds(1)); + requesterResponderSupport.getAllocator().assertHasNoLeaks(); + if (scenario.requestType() != METADATA_PUSH) { + testRequestInterceptor + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .expectNothing(); + } + } + } + + /** Ensures single frame is sent only once racing between requests */ + @ParameterizedTest(name = "{0} should sent requestFrame exactly once if request(n) is racing") + @MethodSource("scenarios") + public void shouldSentRequestFrameOnceInCaseOfRequestRacing(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requestOperator = + (Publisher) scenario.requestOperator(payloadSupplier, activeStreams); + + Payload response = ByteBufPayload.create("test", "test"); + + final AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + requestOperator.subscribe(assertSubscriber); + + RaceTestUtils.race(() -> assertSubscriber.request(1), () -> assertSubscriber.request(1)); + + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); + + if (scenario.requestType().hasInitialRequestN()) { + if (RequestStreamFrameCodec.initialRequestN(sentFrame) == 1) { + FrameAssert.assertThat(activeStreams.getDuplexConnection().awaitFrame()) + .isNotNull() + .hasStreamId(1) + .hasRequestN(1) + .typeOf(REQUEST_N) + .hasNoLeaks(); + } else { + Assertions.assertThat(RequestStreamFrameCodec.initialRequestN(sentFrame)).isEqualTo(2); + } + } + + FrameAssert.assertThat(sentFrame) + .isNotNull() + .hasPayloadSize( + TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length + + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + ((RequesterFrameHandler) requestOperator).handlePayload(response); + ((RequesterFrameHandler) requestOperator).handleComplete(); + + if (scenario.requestType() == REQUEST_CHANNEL) { + ((CoreSubscriber) requestOperator).onComplete(); + FrameAssert.assertThat(activeStreams.getDuplexConnection().awaitFrame()) + .typeOf(COMPLETE) + .hasStreamId(1) + .hasNoLeaks(); + } + + assertSubscriber + .assertTerminated() + .assertValuesWith( + p -> { + Assertions.assertThat(p.release()).isTrue(); + Assertions.assertThat(p.refCnt()).isZero(); + }); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + if (scenario.requestType() != METADATA_PUSH) { + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); + } + } + } + + /** + * Ensures that no ByteBuf is leaked if reassembly is starting and cancel is happening at the same + * time + */ + @ParameterizedTest(name = "Should have no leaks when {0} is canceled during reassembly") + @MethodSource("scenarios") + public void shouldHaveNoLeaksOnReassemblyAndCancelRacing(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requestOperator = + (Publisher) scenario.requestOperator(payloadSupplier, activeStreams); + + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(1); + + requestOperator.subscribe(assertSubscriber); + + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(sentFrame) + .isNotNull() + .hasPayloadSize( + TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length + + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload responsePayload = + TestRequesterResponderSupport.randomPayload(activeStreams.getAllocator()); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments( + activeStreams.getAllocator(), mtu, responsePayload); + RaceTestUtils.race( + assertSubscriber::cancel, + () -> { + FrameHandler frameHandler = (FrameHandler) requestOperator; + int lastFragmentId = fragments.size() - 1; + for (int j = 0; j < fragments.size(); j++) { + ByteBuf frame = fragments.get(j); + frameHandler.handleNext(frame, lastFragmentId != j, lastFragmentId == j); + frame.release(); + } + }); + + List values = assertSubscriber.values(); + if (!values.isEmpty()) { + Assertions.assertThat(values) + .hasSize(1) + .first() + .matches( + p -> { + Assertions.assertThat(p.sliceData()) + .matches(bb -> ByteBufUtil.equals(bb, responsePayload.sliceData())); + Assertions.assertThat(p.hasMetadata()).isEqualTo(responsePayload.hasMetadata()); + Assertions.assertThat(p.sliceMetadata()) + .matches(bb -> ByteBufUtil.equals(bb, responsePayload.sliceMetadata())); + Assertions.assertThat(p.release()).isTrue(); + Assertions.assertThat(p.refCnt()).isZero(); + return true; + }); + } + + if (!activeStreams.getDuplexConnection().isEmpty()) { + if (scenario.requestType() != REQUEST_CHANNEL) { + assertSubscriber.assertNotTerminated(); + } + + final ByteBuf cancellationFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(cancellationFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } else { + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); + } + + Assertions.assertThat(responsePayload.release()).isTrue(); + Assertions.assertThat(responsePayload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + } + } + + /** + * Ensures that in case of racing between next element and cancel we will not have any memory + * leaks + */ + @ParameterizedTest(name = "Should have no leaks when {0} is canceled during reassembly") + @MethodSource("scenarios") + public void shouldHaveNoLeaksOnNextAndCancelRacing(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requestOperator = scenario.requestOperator(payloadSupplier, activeStreams); + + Payload response = ByteBufPayload.create("test", "test"); + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + requestOperator.subscribe((AssertSubscriber) assertSubscriber); + + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(sentFrame) + .isNotNull() + .hasPayloadSize( + TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length + + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + RaceTestUtils.race( + ((Subscription) requestOperator)::cancel, + () -> ((RequesterFrameHandler) requestOperator).handlePayload(response)); + + assertSubscriber.values().forEach(Payload::release); + Assertions.assertThat(response.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + final boolean isEmpty = activeStreams.getDuplexConnection().isEmpty(); + if (!isEmpty) { + final ByteBuf cancellationFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(cancellationFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } else { + assertSubscriber.assertTerminated(); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); + } + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + } + } + + /** + * Ensures that in case we have element reassembling and then it happens the remote sends + * (errorFrame) and downstream subscriber sends cancel() and we have racing between onError and + * cancel we will not have any memory leaks + */ + @ParameterizedTest + @MethodSource("scenarios") + public void shouldHaveNoUnexpectedErrorDuringOnErrorAndCancelRacing(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + boolean[] withReassemblyOptions = new boolean[] {true, false}; + final ArrayList droppedErrors = new ArrayList<>(); + Hooks.onErrorDropped(droppedErrors::add); + + try { + for (boolean withReassembly : withReassemblyOptions) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requestOperator = + scenario.requestOperator(payloadSupplier, activeStreams); + + final StateAssert stateAssert; + if (requestOperator instanceof RequestResponseRequesterMono) { + stateAssert = StateAssert.assertThat((RequestResponseRequesterMono) requestOperator); + } else if (requestOperator instanceof RequestStreamRequesterFlux) { + stateAssert = StateAssert.assertThat((RequestStreamRequesterFlux) requestOperator); + } else { + stateAssert = StateAssert.assertThat((RequestChannelRequesterFlux) requestOperator); + } + + stateAssert.isUnsubscribed(); + final AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + requestOperator.subscribe((AssertSubscriber) assertSubscriber); + + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(1); + + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); + + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(sentFrame) + .isNotNull() + .hasPayloadSize( + TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length + + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + if (withReassembly) { + final ByteBuf fragmentBuf = + activeStreams.getAllocator().buffer().writeBytes(new byte[] {1, 2, 3}); + ((RequesterFrameHandler) requestOperator).handleNext(fragmentBuf, true, false); + // mimic frameHandler behaviour + fragmentBuf.release(); + } + + final RuntimeException testException = new RuntimeException("test"); + RaceTestUtils.race( + ((Subscription) requestOperator)::cancel, + () -> ((RequesterFrameHandler) requestOperator).handleError(testException)); + + activeStreams.assertNoActiveStreams(); + stateAssert.isTerminated(); + + final boolean isEmpty = activeStreams.getDuplexConnection().isEmpty(); + if (!isEmpty) { + final ByteBuf cancellationFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(cancellationFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(droppedErrors).containsExactly(testException); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } else { + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnError(1) + .expectNothing(); + + assertSubscriber.assertTerminated().assertErrorMessage("test"); + } + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + + stateAssert.isTerminated(); + droppedErrors.clear(); + activeStreams.getAllocator().assertHasNoLeaks(); + } + } + } finally { + Hooks.resetOnErrorDropped(); + } + } + + /** + * Ensures that in case of racing between first request and cancel does not going to introduce + * leaks.
+ *
+ * + *

Please note, first request may or may not happen so in case it happened before cancellation + * signal we have to observe + * + *

    + *
  • RequestResponseFrame + *
  • CancellationFrame + *
+ * + *

exactly in that order + * + *

Ensures full serialization of outgoing signal (frames) + */ + @ParameterizedTest + @MethodSource("scenarios") + public void shouldBeConsistentInCaseOfRacingOfCancellationAndRequest(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requestOperator = scenario.requestOperator(payloadSupplier, activeStreams); + + Payload response = ByteBufPayload.create("test", "test"); + + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + requestOperator.subscribe((AssertSubscriber) assertSubscriber); + + RaceTestUtils.race(() -> assertSubscriber.cancel(), () -> assertSubscriber.request(1)); + + if (!activeStreams.getDuplexConnection().isEmpty()) { + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(sentFrame) + .isNotNull() + .typeOf(scenario.requestType()) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf cancelFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } + + ((RequesterFrameHandler) requestOperator).handlePayload(response); + assertSubscriber.values().forEach(Payload::release); + + Assertions.assertThat(response.refCnt()).isZero(); + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + } + } + + /** Ensures that CancelFrame is sent exactly once in case of racing between cancel() methods */ + @ParameterizedTest + @MethodSource("scenarios") + public void shouldSentCancelFrameExactlyOnce(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requesterOperator = + scenario.requestOperator(payloadSupplier, activeStreams); + + Payload response = ByteBufPayload.create("test", "test"); + + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + requesterOperator.subscribe((AssertSubscriber) assertSubscriber); + + assertSubscriber.request(1); + + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(sentFrame) + .isNotNull() + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()) + .hasClientSideStreamId() + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasStreamId(1) + .hasNoLeaks(); + + RaceTestUtils.race( + ((Subscription) requesterOperator)::cancel, ((Subscription) requesterOperator)::cancel); + + final ByteBuf cancelFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + + activeStreams.assertNoActiveStreams(); + + ((RequesterFrameHandler) requesterOperator).handlePayload(response); + assertSubscriber.values().forEach(Payload::release); + Assertions.assertThat(response.refCnt()).isZero(); + + ((RequesterFrameHandler) requesterOperator).handleComplete(); + assertSubscriber.assertNotTerminated(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/ResolvingOperatorTests.java b/rsocket-core/src/test/java/io/rsocket/core/ResolvingOperatorTests.java new file mode 100644 index 000000000..382240c4a --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ResolvingOperatorTests.java @@ -0,0 +1,1030 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.rsocket.RaceTestConstants; +import io.rsocket.internal.subscriber.AssertSubscriber; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Queue; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.assertj.core.api.Condition; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; +import reactor.test.util.RaceTestUtils; + +public class ResolvingOperatorTests { + + @Test + public void shouldExpireValueOnRacingDisposeAndComplete() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final int index = i; + + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer = + (v, t) -> { + if (t != null) { + subscriber.onError(t); + return; + } + + subscriber.onNext(v); + subscriber.onComplete(); + }; + + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingResolution() + .thenAddObserver(consumer) + .assertPendingSubscribers(1) + .assertPendingResolution() + .then(self -> RaceTestUtils.race(() -> self.complete("value" + index), self::dispose)) + .assertDisposeCalled(1) + .assertExpiredExactly("value" + index) + .ifResolvedAssertEqual("value" + index) + .assertIsDisposed(); + + subscriber.assertTerminated(); + + if (!subscriber.errors().isEmpty()) { + Assertions.assertThat(subscriber.errors().get(0)) + .isInstanceOf(CancellationException.class) + .hasMessage("Disposed"); + + } else { + Assertions.assertThat(subscriber.values()).containsExactly("value" + i); + } + } + } + + @Test + public void shouldNotifyAllTheSubscribersUnderRacingBetweenSubscribeAndComplete() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final String valueToSend = "value" + i; + + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer = + (v, t) -> { + if (t != null) { + subscriber.onError(t); + return; + } + + subscriber.onNext(v); + subscriber.onComplete(); + }; + + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer2 = + (v, t) -> { + if (t != null) { + subscriber2.onError(t); + return; + } + + subscriber2.onNext(v); + subscriber2.onComplete(); + }; + + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingSubscribers(0) + .assertPendingResolution() + .then( + self -> { + RaceTestUtils.race(() -> self.complete(valueToSend), () -> self.observe(consumer)); + + subscriber.await(Duration.ofMillis(10)).assertValues(valueToSend).assertComplete(); + }) + .assertDisposeCalled(0) + .assertReceivedExactly(valueToSend) + .assertNothingExpired() + .thenAddObserver(consumer2) + .assertPendingSubscribers(0); + + subscriber2.await(Duration.ofMillis(10)).assertValues(valueToSend).assertComplete(); + } + } + + @Test + public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidate() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final String valueToSend = "value" + i; + final String valueToSend2 = "value2" + i; + + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer = + (v, t) -> { + if (t != null) { + subscriber.onError(t); + return; + } + + subscriber.onNext(v); + subscriber.onComplete(); + }; + + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer2 = + (v, t) -> { + if (t != null) { + subscriber2.onError(t); + return; + } + + subscriber2.onNext(v); + subscriber2.onComplete(); + }; + + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingSubscribers(0) + .assertPendingResolution() + .thenAddObserver(consumer) + .then( + self -> { + self.complete(valueToSend); + + subscriber.await(Duration.ofMillis(10)).assertValues(valueToSend).assertComplete(); + }) + .assertReceivedExactly(valueToSend) + .then( + self -> + RaceTestUtils.race( + self::invalidate, + () -> { + self.observe(consumer2); + if (!subscriber2.isTerminated()) { + self.complete(valueToSend2); + } + })) + .then( + self -> { + if (self.isPending()) { + self.assertReceivedExactly(valueToSend); + } else { + self.assertReceivedExactly(valueToSend, valueToSend2); + } + }) + .assertExpiredExactly(valueToSend) + .assertPendingSubscribers(0) + .assertDisposeCalled(0) + .then( + self -> + subscriber2 + .await(Duration.ofMillis(100)) + .assertValueCount(1) + .assertValuesWith( + v -> { + if (self.subscribers == ResolvingOperator.READY) { + Assertions.assertThat(v).isEqualTo(valueToSend2); + } else { + Assertions.assertThat(v).isEqualTo(valueToSend); + } + }) + .assertComplete()); + } + } + + @Test + public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidates() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final String valueToSend = "value" + i; + final String valueToSend2 = "value_to_possibly_expire" + i; + + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer = + (v, t) -> { + if (t != null) { + subscriber.onError(t); + return; + } + + subscriber.onNext(v); + subscriber.onComplete(); + }; + + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer2 = + (v, t) -> { + if (t != null) { + subscriber2.onError(t); + return; + } + + subscriber2.onNext(v); + subscriber2.onComplete(); + }; + + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingSubscribers(0) + .assertPendingResolution() + .thenAddObserver(consumer) + .then( + self -> { + self.complete(valueToSend); + + subscriber.await(Duration.ofMillis(100)).assertValues(valueToSend).assertComplete(); + }) + .assertReceivedExactly(valueToSend) + .then( + self -> + RaceTestUtils.race( + self::invalidate, + self::invalidate, + () -> { + self.observe(consumer2); + if (!subscriber2.isTerminated()) { + self.complete(valueToSend2); + } + })) + .then( + self -> { + if (!self.isPending()) { + self.assertReceivedExactly(valueToSend, valueToSend2); + } else { + if (self.received.size() > 1) { + self.assertReceivedExactly(valueToSend, valueToSend2); + } else { + self.assertReceivedExactly(valueToSend); + } + } + + Assertions.assertThat(self.expired) + .haveAtMost( + 2, + new Condition<>( + new Predicate() { + int time = 0; + + @Override + public boolean test(Object s) { + if (time++ == 0) { + return valueToSend.equals(s); + } else { + return valueToSend2.equals(s); + } + } + }, + "should matches one of the given values")); + }) + .assertPendingSubscribers(0) + .assertDisposeCalled(0) + .then( + self -> + subscriber2 + .await(Duration.ofMillis(100)) + .assertValueCount(1) + .assertValuesWith( + v -> { + if (self.subscribers == ResolvingOperator.READY) { + Assertions.assertThat(v).isEqualTo(valueToSend2); + } else { + Assertions.assertThat(v).isIn(valueToSend, valueToSend2); + } + }) + .assertComplete()); + } + } + + @Test + public void shouldNotExpireNewlyResolvedValueIfBlockIsRacingWithInvalidate() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final String valueToSend = "value" + i; + final String valueToSend2 = "value2" + i; + + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer = + (v, t) -> { + if (t != null) { + subscriber.onError(t); + return; + } + + subscriber.onNext(v); + subscriber.onComplete(); + }; + + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingSubscribers(0) + .assertPendingResolution() + .thenAddObserver(consumer) + .then( + self -> { + self.complete(valueToSend); + + subscriber.await(Duration.ofMillis(10)).assertValues(valueToSend).assertComplete(); + }) + .assertReceivedExactly(valueToSend) + .then( + self -> + RaceTestUtils.race( + () -> + Assertions.assertThat(self.block(null)) + .matches((v) -> v.equals(valueToSend) || v.equals(valueToSend2)), + self::invalidate, + () -> { + for (; ; ) { + if (self.subscribers != ResolvingOperator.READY) { + self.complete(valueToSend2); + break; + } + } + })) + .then( + self -> { + if (self.isPending()) { + self.assertReceivedExactly(valueToSend); + } else { + self.assertReceivedExactly(valueToSend, valueToSend2); + } + }) + .assertExpiredExactly(valueToSend) + .assertPendingSubscribers(0) + .assertDisposeCalled(0); + } + } + + @Test + public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribers() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final String valueToSend = "value" + i; + + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer = + (v, t) -> { + if (t != null) { + subscriber.onError(t); + return; + } + + subscriber.onNext(v); + subscriber.onComplete(); + }; + + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer2 = + (v, t) -> { + if (t != null) { + subscriber2.onError(t); + return; + } + + subscriber2.onNext(v); + subscriber2.onComplete(); + }; + + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingSubscribers(0) + .assertPendingResolution() + .then( + self -> + RaceTestUtils.race(() -> self.observe(consumer), () -> self.observe(consumer2))) + .assertSubscribeCalled(1) + .assertPendingSubscribers(2) + .then(self -> self.complete(valueToSend)) + .assertPendingSubscribers(0) + .assertReceivedExactly(valueToSend) + .assertNothingExpired() + .assertDisposeCalled(0) + .then( + self -> { + Assertions.assertThat(subscriber.isTerminated()).isTrue(); + Assertions.assertThat(subscriber2.isTerminated()).isTrue(); + + Assertions.assertThat(subscriber.values()).containsExactly(valueToSend); + Assertions.assertThat(subscriber2.values()).containsExactly(valueToSend); + + Assertions.assertThat(self.subscribers).isEqualTo(ResolvingOperator.READY); + + Assertions.assertThat(self.add(consumer)).isEqualTo(ResolvingOperator.READY_STATE); + }); + } + } + + @Test + public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribeAndBlock() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final String valueToSend = "value" + i; + + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); + + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer2 = + (v, t) -> { + if (t != null) { + subscriber2.onError(t); + return; + } + + subscriber2.onNext(v); + subscriber2.onComplete(); + }; + + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingSubscribers(0) + .assertPendingResolution() + .whenSubscribe(self -> self.complete(valueToSend)) + .then( + self -> + RaceTestUtils.race( + () -> { + subscriber.onNext(self.block(null)); + subscriber.onComplete(); + }, + () -> self.observe(consumer2))) + .assertSubscribeCalled(1) + .assertPendingSubscribers(0) + .assertReceivedExactly(valueToSend) + .assertNothingExpired() + .assertDisposeCalled(0) + .then( + self -> { + Assertions.assertThat(subscriber.isTerminated()).isTrue(); + Assertions.assertThat(subscriber2.isTerminated()).isTrue(); + + Assertions.assertThat(subscriber.values()).containsExactly(valueToSend); + Assertions.assertThat(subscriber2.values()).containsExactly(valueToSend); + + Assertions.assertThat(self.subscribers).isEqualTo(ResolvingOperator.READY); + + Assertions.assertThat(self.add(consumer2)).isEqualTo(ResolvingOperator.READY_STATE); + }); + } + } + + @Test + public void shouldEstablishValueOnceInCaseOfRacingBetweenBlocks() { + Duration timeout = Duration.ofMillis(100); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final String valueToSend = "value" + i; + + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); + + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); + + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingSubscribers(0) + .assertPendingResolution() + .whenSubscribe(self -> self.complete(valueToSend)) + .then( + self -> + RaceTestUtils.race( + () -> { + subscriber.onNext(self.block(timeout)); + subscriber.onComplete(); + }, + () -> { + subscriber2.onNext(self.block(timeout)); + subscriber2.onComplete(); + })) + .assertSubscribeCalled(1) + .assertPendingSubscribers(0) + .assertReceivedExactly(valueToSend) + .assertNothingExpired() + .assertDisposeCalled(0) + .then( + self -> { + Assertions.assertThat(subscriber.isTerminated()).isTrue(); + Assertions.assertThat(subscriber2.isTerminated()).isTrue(); + + Assertions.assertThat(subscriber.values()).containsExactly(valueToSend); + Assertions.assertThat(subscriber2.values()).containsExactly(valueToSend); + + Assertions.assertThat(self.subscribers).isEqualTo(ResolvingOperator.READY); + + Assertions.assertThat(self.add((v, t) -> {})) + .isEqualTo(ResolvingOperator.READY_STATE); + }); + } + } + + @Test + public void shouldExpireValueOnRacingDisposeAndError() { + Hooks.onErrorDropped(t -> {}); + RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer = + (v, t) -> { + if (t != null) { + subscriber.onError(t); + return; + } + + subscriber.onNext(v); + subscriber.onComplete(); + }; + + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer2 = + (v, t) -> { + if (t != null) { + subscriber2.onError(t); + return; + } + + subscriber2.onNext(v); + subscriber2.onComplete(); + }; + + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingSubscribers(0) + .assertPendingResolution() + .thenAddObserver(consumer) + .assertSubscribeCalled(1) + .assertPendingSubscribers(1) + .then(self -> RaceTestUtils.race(() -> self.terminate(runtimeException), self::dispose)) + .assertPendingSubscribers(0) + .assertNothingExpired() + .assertDisposeCalled(1) + .then( + self -> { + Assertions.assertThat(self.subscribers).isEqualTo(ResolvingOperator.TERMINATED); + + Assertions.assertThat(self.add((v, t) -> {})) + .isEqualTo(ResolvingOperator.TERMINATED_STATE); + }) + .thenAddObserver(consumer2); + + subscriber + .await(Duration.ofMillis(10)) + .assertErrorWith( + t -> { + if (t instanceof CancellationException) { + Assertions.assertThat(t) + .isInstanceOf(CancellationException.class) + .hasMessage("Disposed"); + } else { + Assertions.assertThat(t).isInstanceOf(RuntimeException.class).hasMessage("test"); + } + }); + + subscriber2 + .await(Duration.ofMillis(10)) + .assertErrorWith( + t -> { + if (t instanceof CancellationException) { + Assertions.assertThat(t) + .isInstanceOf(CancellationException.class) + .hasMessage("Disposed"); + } else { + Assertions.assertThat(t).isInstanceOf(RuntimeException.class).hasMessage("test"); + } + }); + + // no way to guarantee equality because of racing + // Assertions.assertThat(processor.getError()) + // .isEqualTo(processor2.getError()); + } + } + + @Test + public void shouldThrowOnBlocking() { + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingSubscribers(0) + .assertPendingResolution() + .then( + self -> + Assertions.assertThatThrownBy(() -> self.block(Duration.ofMillis(100))) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Timeout on Mono blocking read")) + .assertPendingSubscribers(0) + .assertNothingExpired() + .assertNothingReceived() + .assertDisposeCalled(0); + } + + @Test + public void shouldThrowOnBlockingIfHasAlreadyTerminated() { + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingSubscribers(0) + .assertPendingResolution() + .whenSubscribe(self -> self.terminate(new RuntimeException("test"))) + .then( + self -> + Assertions.assertThatThrownBy(() -> self.block(Duration.ofMillis(100))) + .isInstanceOf(RuntimeException.class) + .hasMessage("test") + .hasSuppressedException(new Exception("Terminated with an error"))) + .assertPendingSubscribers(0) + .assertNothingExpired() + .assertNothingReceived() + .assertDisposeCalled(1); + } + + static Stream, Publisher>> innerCases() { + return Stream.of( + (self) -> { + final Sinks.One processor = Sinks.unsafe().one(); + final ResolvingOperator.DeferredResolution operator = + new ResolvingOperator.DeferredResolution( + self, new SinkOneSubscriber(processor)) { + @Override + public void accept(String v, Throwable t) { + if (t != null) { + onError(t); + return; + } + + onNext(v); + } + }; + return processor + .asMono() + .doOnSubscribe(s -> self.observe(operator)) + .doOnCancel(operator::cancel); + }, + (self) -> { + final Sinks.One processor = Sinks.unsafe().one(); + final SinkOneSubscriber subscriber = new SinkOneSubscriber(processor); + final ResolvingOperator.MonoDeferredResolutionOperator operator = + new ResolvingOperator.MonoDeferredResolutionOperator<>(self, subscriber); + subscriber.onSubscribe(operator); + return processor + .asMono() + .doOnSubscribe(s -> self.observe(operator)) + .doOnCancel(operator::cancel); + }); + } + + @ParameterizedTest + @MethodSource("innerCases") + public void shouldBePossibleToRemoveThemSelvesFromTheList_CancellationTest( + Function, Publisher> caseProducer) { + ResolvingTest.create() + .then( + self -> { + Publisher resolvingInner = caseProducer.apply(self); + StepVerifier.create(resolvingInner) + .expectSubscription() + .then(() -> self.assertSubscribeCalled(1).assertPendingSubscribers(1)) + .thenCancel() + .verify(Duration.ofMillis(100)); + }) + .assertPendingSubscribers(0) + .assertNothingExpired() + .then(self -> self.complete("test")) + .assertReceivedExactly("test"); + } + + @ParameterizedTest + @MethodSource("innerCases") + public void shouldExpireValueOnDispose( + Function, Publisher> caseProducer) { + ResolvingTest.create() + .then( + self -> { + Publisher resolvingInner = caseProducer.apply(self); + + StepVerifier.create(resolvingInner) + .expectSubscription() + .then(() -> self.complete("test")) + .expectNext("test") + .expectComplete() + .verify(Duration.ofMillis(100)); + }) + .assertPendingSubscribers(0) + .assertNothingExpired() + .assertReceivedExactly("test") + .then(ResolvingOperator::dispose) + .assertExpiredExactly("test") + .assertDisposeCalled(1); + } + + @ParameterizedTest + @MethodSource("innerCases") + public void shouldNotifyAllTheSubscribers( + Function, Publisher> caseProducer) { + + AssertSubscriber sub1 = AssertSubscriber.create(); + AssertSubscriber sub2 = AssertSubscriber.create(); + AssertSubscriber sub3 = AssertSubscriber.create(); + AssertSubscriber sub4 = AssertSubscriber.create(); + + final ArrayList> processors = + new ArrayList<>(RaceTestConstants.REPEATS * 2); + + ResolvingTest.create() + .assertDisposeCalled(0) + .assertPendingSubscribers(0) + .assertNothingExpired() + .assertNothingReceived() + .assertPendingResolution() + .then( + self -> { + caseProducer.apply(self).subscribe(sub1); + caseProducer.apply(self).subscribe(sub2); + caseProducer.apply(self).subscribe(sub3); + caseProducer.apply(self).subscribe(sub4); + }) + .assertSubscribeCalled(1) + .assertPendingSubscribers(4) + .then( + self -> { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + AssertSubscriber subA = AssertSubscriber.create(); + AssertSubscriber subB = AssertSubscriber.create(); + processors.add(subA); + processors.add(subB); + RaceTestUtils.race( + () -> caseProducer.apply(self).subscribe(subA), + () -> caseProducer.apply(self).subscribe(subB)); + } + }) + .assertSubscribeCalled(1) + .assertPendingSubscribers(RaceTestConstants.REPEATS * 2 + 4) + .then(self -> sub1.cancel()) + .assertPendingSubscribers(RaceTestConstants.REPEATS * 2 + 3) + .then( + self -> { + String valueToSend = "value"; + self.complete(valueToSend); + + Assertions.assertThat(sub1.isTerminated()).isFalse(); + Assertions.assertThat(sub2.values()).containsExactly(valueToSend); + Assertions.assertThat(sub3.values()).containsExactly(valueToSend); + Assertions.assertThat(sub4.values()).containsExactly(valueToSend); + + for (AssertSubscriber sub : processors) { + Assertions.assertThat(sub.values()).containsExactly(valueToSend); + Assertions.assertThat(sub.isTerminated()).isTrue(); + } + }) + .assertPendingSubscribers(0) + .assertNothingExpired() + .assertReceivedExactly("value"); + } + + @Test + public void shouldBeSerialIfRacyMonoInner() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + long[] requested = new long[] {0}; + Subscription mockSubscription = Mockito.mock(Subscription.class); + Mockito.doAnswer( + a -> { + long argument = a.getArgument(0); + return requested[0] += argument; + }) + .when(mockSubscription) + .request(Mockito.anyLong()); + ResolvingOperator.DeferredResolution resolution = + new ResolvingOperator.DeferredResolution( + ResolvingTest.create(), AssertSubscriber.create(0)) { + + @Override + public void accept(Object o, Object o2) {} + }; + + resolution.request(5); + + RaceTestUtils.race( + () -> resolution.onSubscribe(mockSubscription), + () -> { + resolution.request(10); + resolution.request(10); + resolution.request(10); + }); + + resolution.request(15); + + Assertions.assertThat(requested[0]).isEqualTo(50L); + } + } + + @Test + public void shouldExpireValueExactlyOnceOnRacingBetweenInvalidates() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingResolution() + .then(self -> self.complete("test")) + .assertReceivedExactly("test") + .then(self -> RaceTestUtils.race(self::invalidate, self::invalidate)) + .assertExpiredExactly("test"); + } + } + + @Test + public void shouldExpireValueExactlyOnceOnRacingBetweenInvalidateAndDispose() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingResolution() + .then(self -> self.complete("test")) + .assertReceivedExactly("test") + .then(self -> RaceTestUtils.race(self::invalidate, self::dispose)) + .assertExpiredExactly("test"); + } + } + + static class ResolvingTest extends ResolvingOperator { + + final AtomicInteger subscribeCalls = new AtomicInteger(); + final AtomicInteger onDisposeCalls = new AtomicInteger(); + + final Queue received = new ConcurrentLinkedQueue<>(); + final Queue expired = new ConcurrentLinkedQueue<>(); + + Consumer> whenSubscribeConsumer = (self) -> {}; + + static ResolvingTest create() { + return new ResolvingTest<>(); + } + + public ResolvingTest assertPendingSubscribers(int cnt) { + Assertions.assertThat(this.subscribers.length).isEqualTo(cnt); + + return this; + } + + public ResolvingTest whenSubscribe(Consumer> consumer) { + this.whenSubscribeConsumer = consumer; + return this; + } + + public ResolvingTest then(Consumer> consumer) { + consumer.accept(this); + + return this; + } + + public ResolvingTest thenAddObserver(BiConsumer consumer) { + this.observe(consumer); + return this; + } + + public ResolvingTest assertPendingResolution() { + Assertions.assertThat(this.isPending()).isTrue(); + + return this; + } + + public ResolvingTest assertIsDisposed() { + Assertions.assertThat(this.isDisposed()).isTrue(); + + return this; + } + + public ResolvingTest assertSubscribeCalled(int times) { + Assertions.assertThat(subscribeCalls).hasValue(times); + + return this; + } + + public ResolvingTest assertDisposeCalled(int times) { + Assertions.assertThat(onDisposeCalls).hasValue(times); + return this; + } + + public ResolvingTest assertNothingExpired() { + return assertExpiredExactly(); + } + + public ResolvingTest assertExpiredExactly(T... values) { + Assertions.assertThat(expired).hasSize(values.length).containsExactly(values); + + return this; + } + + public ResolvingTest assertNothingReceived() { + return assertReceivedExactly(); + } + + public ResolvingTest assertReceivedExactly(T... values) { + Assertions.assertThat(received).hasSize(values.length).containsExactly(values); + + return this; + } + + public ResolvingTest ifResolvedAssertEqual(T value) { + if (received.size() > 0) { + Assertions.assertThat(received).hasSize(1).containsExactly(value); + } + + return this; + } + + @Override + protected void doOnValueResolved(T value) { + received.offer(value); + } + + @Override + protected void doOnValueExpired(T value) { + expired.offer(value); + } + + @Override + protected void doSubscribe() { + whenSubscribeConsumer.accept(this); + subscribeCalls.incrementAndGet(); + } + + @Override + protected void doOnDispose() { + onDisposeCalls.incrementAndGet(); + } + } + + private static class SinkOneSubscriber implements CoreSubscriber { + + private final Sinks.One processor; + private boolean valueReceived; + + public SinkOneSubscriber(Sinks.One processor) { + this.processor = processor; + } + + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(String s) { + valueReceived = true; + processor.tryEmitValue(s); + } + + @Override + public void onError(Throwable t) { + processor.tryEmitError(t); + } + + @Override + public void onComplete() { + if (!valueReceived) { + processor.tryEmitEmpty(); + } + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java b/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java new file mode 100755 index 000000000..4f7821e4a --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java @@ -0,0 +1,477 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.frame.FrameType.METADATA_PUSH; +import static io.rsocket.frame.FrameType.NEXT; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_FNF; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; + +import io.netty.buffer.ByteBuf; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.plugins.RequestInterceptor; +import io.rsocket.plugins.TestRequestInterceptor; +import io.rsocket.test.util.TestDuplexConnection; +import java.util.ArrayList; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Stream; +import org.assertj.core.api.Assumptions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.test.publisher.TestPublisher; + +public class ResponderOperatorsCommonTest { + + interface Scenario { + FrameType requestType(); + + int maxElements(); + + ResponderFrameHandler responseOperator( + long initialRequestN, + Payload firstPayload, + TestRequesterResponderSupport streamManager, + RSocket handler); + + ResponderFrameHandler responseOperator( + long initialRequestN, + ByteBuf firstFragment, + TestRequesterResponderSupport streamManager, + RSocket handler); + } + + static Stream scenarios() { + return Stream.of( + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_RESPONSE; + } + + @Override + public int maxElements() { + return 1; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + ByteBuf firstFragment, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestResponseResponderSubscriber subscriber = + new RequestResponseResponderSubscriber( + streamId, firstFragment, streamManager, handler); + streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_RESPONSE, null); + } + + return subscriber; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + Payload firstPayload, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestResponseResponderSubscriber subscriber = + new RequestResponseResponderSubscriber(streamId, streamManager); + streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_RESPONSE, null); + } + + return handler.requestResponse(firstPayload).subscribeWith(subscriber); + } + + @Override + public String toString() { + return RequestResponseRequesterMono.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_STREAM; + } + + @Override + public int maxElements() { + return Integer.MAX_VALUE; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + ByteBuf firstFragment, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestStreamResponderSubscriber subscriber = + new RequestStreamResponderSubscriber( + streamId, initialRequestN, firstFragment, streamManager, handler); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_STREAM, null); + } + + streamManager.activeStreams.put(streamId, subscriber); + return subscriber; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + Payload firstPayload, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestStreamResponderSubscriber subscriber = + new RequestStreamResponderSubscriber(streamId, initialRequestN, streamManager); + streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_STREAM, null); + } + + return handler.requestStream(firstPayload).subscribeWith(subscriber); + } + + @Override + public String toString() { + return RequestStreamResponderSubscriber.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_CHANNEL; + } + + @Override + public int maxElements() { + return Integer.MAX_VALUE; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + ByteBuf firstFragment, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestChannelResponderSubscriber subscriber = + new RequestChannelResponderSubscriber( + streamId, initialRequestN, firstFragment, streamManager, handler); + streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_CHANNEL, null); + } + + return subscriber; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + Payload firstPayload, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestChannelResponderSubscriber responderSubscriber = + new RequestChannelResponderSubscriber( + streamId, initialRequestN, firstPayload, streamManager); + streamManager.activeStreams.put(streamId, responderSubscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_CHANNEL, null); + } + + return handler.requestChannel(responderSubscriber).subscribeWith(responderSubscriber); + } + + @Override + public String toString() { + return RequestChannelResponderSubscriber.class.getSimpleName(); + } + }); + } + + static class TestHandler implements RSocket { + + final TestPublisher producer; + final AssertSubscriber consumer; + + TestHandler(TestPublisher producer, AssertSubscriber consumer) { + this.producer = producer; + this.consumer = consumer; + } + + @Override + public Mono fireAndForget(Payload payload) { + consumer.onSubscribe(Operators.emptySubscription()); + consumer.onNext(payload); + consumer.onComplete(); + return producer.mono().then(); + } + + @Override + public Mono requestResponse(Payload payload) { + consumer.onSubscribe(Operators.emptySubscription()); + consumer.onNext(payload); + consumer.onComplete(); + return producer.mono(); + } + + @Override + public Flux requestStream(Payload payload) { + consumer.onSubscribe(Operators.emptySubscription()); + consumer.onNext(payload); + consumer.onComplete(); + return producer.flux(); + } + + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(consumer); + return producer.flux(); + } + } + + @ParameterizedTest + @MethodSource("scenarios") + void shouldHandleRequest(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()).isNotIn(REQUEST_FNF, METADATA_PUSH); + + TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + TestRequesterResponderSupport testRequesterResponderSupport = + TestRequesterResponderSupport.client(testRequestInterceptor); + final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); + final TestDuplexConnection sender = testRequesterResponderSupport.getDuplexConnection(); + TestPublisher testPublisher = TestPublisher.create(); + TestHandler testHandler = new TestHandler(testPublisher, new AssertSubscriber<>(0)); + + ResponderFrameHandler responderFrameHandler = + scenario.responseOperator( + Long.MAX_VALUE, + TestRequesterResponderSupport.genericPayload(allocator), + testRequesterResponderSupport, + testHandler); + + Payload randomPayload = TestRequesterResponderSupport.randomPayload(allocator); + testPublisher.assertWasSubscribed(); + testPublisher.next(randomPayload.retain()); + testPublisher.complete(); + + FrameAssert.assertThat(sender.awaitFrame()) + .isNotNull() + .hasStreamId(1) + .typeOf(scenario.requestType() == REQUEST_RESPONSE ? FrameType.NEXT_COMPLETE : NEXT) + .hasPayloadSize( + randomPayload.data().readableBytes() + randomPayload.sliceMetadata().readableBytes()) + .hasData(randomPayload.data()) + .hasNoLeaks(); + + PayloadAssert.assertThat(randomPayload).hasNoLeaks(); + + if (scenario.requestType() != REQUEST_RESPONSE) { + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.COMPLETE) + .hasStreamId(1) + .hasNoLeaks(); + + if (scenario.requestType() == REQUEST_CHANNEL) { + testHandler.consumer.request(2); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.REQUEST_N) + .hasStreamId(1) + .hasRequestN(1) + .hasNoLeaks(); + + responderFrameHandler.handleComplete(); + testHandler.consumer.assertComplete(); + } + } + + testHandler + .consumer + .assertValueCount(1) + .assertValuesWith(p -> PayloadAssert.assertThat(p).hasNoLeaks()); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("scenarios") + void shouldHandleFragmentedRequest(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()).isNotIn(REQUEST_FNF, METADATA_PUSH); + + TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + TestRequesterResponderSupport testRequesterResponderSupport = + TestRequesterResponderSupport.client(testRequestInterceptor); + final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); + final TestDuplexConnection sender = testRequesterResponderSupport.getDuplexConnection(); + TestPublisher testPublisher = TestPublisher.create(); + TestHandler testHandler = new TestHandler(testPublisher, new AssertSubscriber<>(0)); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, firstPayload); + + ByteBuf firstFragment = fragments.remove(0); + ResponderFrameHandler responderFrameHandler = + scenario.responseOperator( + Long.MAX_VALUE, firstFragment, testRequesterResponderSupport, testHandler); + firstFragment.release(); + + testPublisher.assertWasNotSubscribed(); + testRequesterResponderSupport.assertHasStream(1, responderFrameHandler); + + for (int i = 0; i < fragments.size(); i++) { + ByteBuf fragment = fragments.get(i); + boolean hasFollows = i != fragments.size() - 1; + responderFrameHandler.handleNext(fragment, hasFollows, !hasFollows); + fragment.release(); + } + + Payload randomPayload = TestRequesterResponderSupport.randomPayload(allocator); + testPublisher.assertWasSubscribed(); + testPublisher.next(randomPayload.retain()); + testPublisher.complete(); + + FrameAssert.assertThat(sender.awaitFrame()) + .isNotNull() + .hasStreamId(1) + .typeOf(scenario.requestType() == REQUEST_RESPONSE ? FrameType.NEXT_COMPLETE : NEXT) + .hasPayloadSize( + randomPayload.data().readableBytes() + randomPayload.sliceMetadata().readableBytes()) + .hasData(randomPayload.data()) + .hasNoLeaks(); + + PayloadAssert.assertThat(randomPayload).hasNoLeaks(); + + if (scenario.requestType() != REQUEST_RESPONSE) { + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.COMPLETE) + .hasStreamId(1) + .hasNoLeaks(); + + if (scenario.requestType() == REQUEST_CHANNEL) { + testHandler.consumer.request(2); + FrameAssert.assertThat(sender.pollFrame()).isNull(); + } + } + + testHandler + .consumer + .assertValueCount(1) + .assertValuesWith( + p -> PayloadAssert.assertThat(p).hasData(firstPayload.sliceData()).hasNoLeaks()) + .assertComplete(); + + testRequesterResponderSupport.assertNoActiveStreams(); + + firstPayload.release(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); + + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("scenarios") + void shouldHandleInterruptedFragmentation(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()).isNotIn(REQUEST_FNF, METADATA_PUSH); + + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + TestRequesterResponderSupport testRequesterResponderSupport = + TestRequesterResponderSupport.client(testRequestInterceptor); + final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); + TestPublisher testPublisher = TestPublisher.create(); + TestHandler testHandler = new TestHandler(testPublisher, new AssertSubscriber<>(0)); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, firstPayload); + firstPayload.release(); + + ByteBuf firstFragment = fragments.remove(0); + ResponderFrameHandler responderFrameHandler = + scenario.responseOperator( + Long.MAX_VALUE, firstFragment, testRequesterResponderSupport, testHandler); + firstFragment.release(); + + testPublisher.assertWasNotSubscribed(); + testRequesterResponderSupport.assertHasStream(1, responderFrameHandler); + + for (int i = 0; i < fragments.size(); i++) { + ByteBuf fragment = fragments.get(i); + boolean hasFollows = i != fragments.size() - 1; + if (hasFollows) { + responderFrameHandler.handleNext(fragment, true, false); + } else { + responderFrameHandler.handleCancel(); + } + fragment.release(); + } + + testPublisher.assertWasNotSubscribed(); + testRequesterResponderSupport.assertNoActiveStreams(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/SendUtilsTest.java b/rsocket-core/src/test/java/io/rsocket/core/SendUtilsTest.java new file mode 100644 index 000000000..9a51b9419 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/SendUtilsTest.java @@ -0,0 +1,31 @@ +package io.rsocket.core; + +import static org.mockito.Mockito.*; + +import io.netty.util.ReferenceCounted; +import java.util.function.Consumer; +import org.junit.jupiter.api.Test; + +public class SendUtilsTest { + + @Test + void droppedElementsConsumerShouldAcceptOtherTypesThanReferenceCounted() { + Consumer value = extractDroppedElementConsumer(); + value.accept(new Object()); + } + + @Test + void droppedElementsConsumerReleaseReference() { + ReferenceCounted referenceCounted = mock(ReferenceCounted.class); + when(referenceCounted.release()).thenReturn(true); + + Consumer value = extractDroppedElementConsumer(); + value.accept(referenceCounted); + + verify(referenceCounted).release(); + } + + private static Consumer extractDroppedElementConsumer() { + return (Consumer) SendUtils.DISCARD_CONTEXT.stream().findAny().get().getValue(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java new file mode 100644 index 000000000..87c3a865f --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java @@ -0,0 +1,209 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.transport.ServerTransport.ConnectionAcceptor; +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Closeable; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.Exceptions; +import io.rsocket.exceptions.RejectedSetupException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.transport.ServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; + +public class SetupRejectionTest { + + @Test + void responderRejectSetup() { + SingleConnectionTransport transport = new SingleConnectionTransport(); + + String errorMsg = "error"; + RejectingAcceptor acceptor = new RejectingAcceptor(errorMsg); + RSocketServer.create().acceptor(acceptor).bind(transport).block(); + + transport.connect(); + + ByteBuf sentFrame = transport.awaitSent(); + assertThat(FrameHeaderCodec.frameType(sentFrame)).isEqualTo(FrameType.ERROR); + RuntimeException error = Exceptions.from(0, sentFrame); + sentFrame.release(); + assertThat(errorMsg).isEqualTo(error.getMessage()); + assertThat(error).isInstanceOf(RejectedSetupException.class); + RSocket acceptorSender = acceptor.senderRSocket().block(); + assertThat(acceptorSender.isDisposed()).isTrue(); + transport.allocator.assertHasNoLeaks(); + } + + @Test + void requesterStreamsTerminatedOnZeroErrorFrame() { + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + TestDuplexConnection conn = new TestDuplexConnection(allocator); + Sinks.Empty onThisSideClosedSink = Sinks.empty(); + + RSocketRequester rSocket = + new RSocketRequester( + conn, + DefaultPayload::create, + StreamIdSupplier.clientSupplier(), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + 0, + 0, + null, + __ -> null, + null, + onThisSideClosedSink, + onThisSideClosedSink.asMono()); + + String errorMsg = "error"; + + StepVerifier.create( + rSocket + .requestResponse(DefaultPayload.create("test")) + .doOnRequest( + ignored -> + conn.addToReceivedBuffer( + ErrorFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 0, + new RejectedSetupException(errorMsg))))) + .expectErrorMatches( + err -> err instanceof RejectedSetupException && errorMsg.equals(err.getMessage())) + .verify(Duration.ofSeconds(5)); + + assertThat(rSocket.isDisposed()).isTrue(); + allocator.assertHasNoLeaks(); + } + + @Test + void requesterNewStreamsTerminatedAfterZeroErrorFrame() { + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + TestDuplexConnection conn = new TestDuplexConnection(allocator); + Sinks.Empty onThisSideClosedSink = Sinks.empty(); + RSocketRequester rSocket = + new RSocketRequester( + conn, + DefaultPayload::create, + StreamIdSupplier.clientSupplier(), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + 0, + 0, + null, + __ -> null, + null, + onThisSideClosedSink, + onThisSideClosedSink.asMono()); + + conn.addToReceivedBuffer( + ErrorFrameCodec.encode(ByteBufAllocator.DEFAULT, 0, new RejectedSetupException("error"))); + + StepVerifier.create( + rSocket + .requestResponse(DefaultPayload.create("test")) + .delaySubscription(Duration.ofMillis(100))) + .expectErrorMatches( + err -> err instanceof RejectedSetupException && "error".equals(err.getMessage())) + .verify(Duration.ofSeconds(5)); + allocator.assertHasNoLeaks(); + } + + private static class RejectingAcceptor implements SocketAcceptor { + private final String errorMessage; + private final Sinks.Many senderRSockets = + Sinks.many().unicast().onBackpressureBuffer(); + + public RejectingAcceptor(String errorMessage) { + this.errorMessage = errorMessage; + } + + @Override + public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) { + senderRSockets.tryEmitNext(sendingSocket); + return Mono.error(new RuntimeException(errorMessage)); + } + + public Mono senderRSocket() { + return senderRSockets.asFlux().next(); + } + } + + private static class SingleConnectionTransport implements ServerTransport { + + private final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + private final TestDuplexConnection conn = new TestDuplexConnection(allocator); + + @Override + public Mono start(ConnectionAcceptor acceptor) { + return Mono.just(new TestCloseable(acceptor, conn)); + } + + public ByteBuf awaitSent() { + return conn.awaitFrame(); + } + + public void connect() { + Payload payload = DefaultPayload.create(DefaultPayload.EMPTY_BUFFER); + ByteBuf setup = SetupFrameCodec.encode(allocator, false, 0, 42, "mdMime", "dMime", payload); + + conn.addToReceivedBuffer(setup); + } + } + + private static class TestCloseable implements Closeable { + + private final DuplexConnection conn; + + TestCloseable(ConnectionAcceptor acceptor, DuplexConnection conn) { + this.conn = conn; + Mono.from(acceptor.apply(conn)).subscribe(notUsed -> {}, err -> conn.dispose()); + } + + @Override + public Mono onClose() { + return conn.onClose(); + } + + @Override + public void dispose() { + conn.dispose(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/ShouldHaveFlag.java b/rsocket-core/src/test/java/io/rsocket/core/ShouldHaveFlag.java new file mode 100644 index 000000000..88e0dc8e2 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ShouldHaveFlag.java @@ -0,0 +1,98 @@ +package io.rsocket.core; + +import static io.rsocket.core.StateUtils.REQUEST_MASK; +import static io.rsocket.core.StateUtils.SUBSCRIBED_FLAG; +import static io.rsocket.core.StateUtils.extractRequestN; + +import java.util.HashMap; +import java.util.Map; +import org.assertj.core.error.BasicErrorMessageFactory; +import org.assertj.core.error.ErrorMessageFactory; + +class ShouldHaveFlag extends BasicErrorMessageFactory { + + static final Map FLAGS_NAMES = + new HashMap() { + { + put(StateUtils.UNSUBSCRIBED_STATE, "UNSUBSCRIBED"); + put(StateUtils.TERMINATED_STATE, "TERMINATED"); + put(SUBSCRIBED_FLAG, "SUBSCRIBED"); + put(StateUtils.REQUEST_MASK, "REQUESTED(%s)"); + put(StateUtils.FIRST_FRAME_SENT_FLAG, "FIRST_FRAME_SENT"); + put(StateUtils.REASSEMBLING_FLAG, "REASSEMBLING"); + put(StateUtils.INBOUND_TERMINATED_FLAG, "INBOUND_TERMINATED"); + put(StateUtils.OUTBOUND_TERMINATED_FLAG, "OUTBOUND_TERMINATED"); + } + }; + + static final String SHOULD_HAVE_FLAG = "Expected state\n\t%s\nto have\n\t%s\nbut had\n\t[%s]"; + + private ShouldHaveFlag(long currentState, String expectedFlag, String actualFlags) { + super(SHOULD_HAVE_FLAG, toBinaryString(currentState), expectedFlag, actualFlags); + } + + static ErrorMessageFactory shouldHaveFlag(long currentState, long expectedFlag) { + String stateAsString = extractStateAsString(currentState); + return new ShouldHaveFlag(currentState, FLAGS_NAMES.get(expectedFlag), stateAsString); + } + + static ErrorMessageFactory shouldHaveRequestN(long currentState, long expectedRequestN) { + String stateAsString = extractStateAsString(currentState); + return new ShouldHaveFlag( + currentState, + String.format( + FLAGS_NAMES.get(REQUEST_MASK), + expectedRequestN == Integer.MAX_VALUE ? "MAX" : expectedRequestN), + stateAsString); + } + + static ErrorMessageFactory shouldHaveRequestNBetween( + long currentState, long expectedRequestNMin, long expectedRequestNMax) { + String stateAsString = extractStateAsString(currentState); + return new ShouldHaveFlag( + currentState, + String.format( + FLAGS_NAMES.get(REQUEST_MASK), + (expectedRequestNMin == Integer.MAX_VALUE ? "MAX" : expectedRequestNMin) + + " - " + + (expectedRequestNMax == Integer.MAX_VALUE ? "MAX" : expectedRequestNMax)), + stateAsString); + } + + private static String extractStateAsString(long currentState) { + StringBuilder stringBuilder = new StringBuilder(); + long flag = 1L << 31; + for (int i = 0; i < 33; i++, flag <<= 1) { + if ((currentState & flag) == flag) { + if (stringBuilder.length() > 0) { + stringBuilder.append(", "); + } + stringBuilder.append(FLAGS_NAMES.get(flag)); + } + } + long requestN = extractRequestN(currentState); + if (requestN > 0) { + if (stringBuilder.length() > 0) { + stringBuilder.append(", "); + } + stringBuilder.append( + String.format( + FLAGS_NAMES.get(REQUEST_MASK), requestN >= Integer.MAX_VALUE ? "MAX" : requestN)); + } + return stringBuilder.toString(); + } + + static String toBinaryString(long state) { + StringBuilder binaryString = new StringBuilder(Long.toBinaryString(state)); + + int diff = 64 - binaryString.length(); + for (int i = 0; i < diff; i++) { + binaryString.insert(0, "0"); + } + + binaryString.insert(33, "_"); + binaryString.insert(0, "0b"); + + return binaryString.toString(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/ShouldNotHaveFlag.java b/rsocket-core/src/test/java/io/rsocket/core/ShouldNotHaveFlag.java new file mode 100644 index 000000000..e281e548c --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ShouldNotHaveFlag.java @@ -0,0 +1,73 @@ +package io.rsocket.core; + +import static io.rsocket.core.StateUtils.REQUEST_MASK; +import static io.rsocket.core.StateUtils.SUBSCRIBED_FLAG; +import static io.rsocket.core.StateUtils.extractRequestN; + +import java.util.HashMap; +import java.util.Map; +import org.assertj.core.error.BasicErrorMessageFactory; +import org.assertj.core.error.ErrorMessageFactory; + +class ShouldNotHaveFlag extends BasicErrorMessageFactory { + + static final Map FLAGS_NAMES = + new HashMap() { + { + put(StateUtils.UNSUBSCRIBED_STATE, "UNSUBSCRIBED"); + put(StateUtils.TERMINATED_STATE, "TERMINATED"); + put(SUBSCRIBED_FLAG, "SUBSCRIBED"); + put(StateUtils.REQUEST_MASK, "REQUESTED(%n)"); + put(StateUtils.FIRST_FRAME_SENT_FLAG, "FIRST_FRAME_SENT"); + put(StateUtils.REASSEMBLING_FLAG, "REASSEMBLING"); + put(StateUtils.INBOUND_TERMINATED_FLAG, "INBOUND_TERMINATED"); + put(StateUtils.OUTBOUND_TERMINATED_FLAG, "OUTBOUND_TERMINATED"); + } + }; + + static final String SHOULD_NOT_HAVE_FLAG = + "Expected state\n\t%s\nto not have\n\t%s\nbut had\n\t[%s]"; + + private ShouldNotHaveFlag(long currentState, long expectedFlag, String actualFlags) { + super( + SHOULD_NOT_HAVE_FLAG, + toBinaryString(currentState), + FLAGS_NAMES.get(expectedFlag), + actualFlags); + } + + static ErrorMessageFactory shouldNotHaveFlag(long currentState, long expectedFlag) { + StringBuilder stringBuilder = new StringBuilder(); + long flag = 1L << 31; + for (int i = 0; i < 33; i++, flag <<= 1) { + if ((currentState & flag) == flag) { + if (stringBuilder.length() > 0) { + stringBuilder.append(", "); + } + stringBuilder.append(FLAGS_NAMES.get(flag)); + } + } + long requestN = extractRequestN(currentState); + if (requestN > 0) { + if (stringBuilder.length() > 0) { + stringBuilder.append(", "); + } + stringBuilder.append(String.format(FLAGS_NAMES.get(REQUEST_MASK), requestN)); + } + return new ShouldNotHaveFlag(currentState, expectedFlag, stringBuilder.toString()); + } + + static String toBinaryString(long state) { + StringBuilder binaryString = new StringBuilder(Long.toBinaryString(state)); + + int diff = 64 - binaryString.length(); + for (int i = 0; i < diff; i++) { + binaryString.insert(0, "0"); + } + + binaryString.insert(33, "_"); + binaryString.insert(0, "0b"); + + return binaryString.toString(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/StateAssert.java b/rsocket-core/src/test/java/io/rsocket/core/StateAssert.java new file mode 100644 index 000000000..64253984b --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/StateAssert.java @@ -0,0 +1,161 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.ShouldHaveFlag.*; +import static io.rsocket.core.ShouldNotHaveFlag.shouldNotHaveFlag; +import static io.rsocket.core.StateUtils.*; + +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.assertj.core.api.AbstractAssert; +import org.assertj.core.internal.Failures; + +public class StateAssert extends AbstractAssert, AtomicLongFieldUpdater> { + + public static StateAssert assertThat(AtomicLongFieldUpdater updater, T instance) { + return new StateAssert<>(updater, instance); + } + + public static StateAssert assertThat( + FireAndForgetRequesterMono instance) { + return new StateAssert<>(FireAndForgetRequesterMono.STATE, instance); + } + + public static StateAssert assertThat( + RequestResponseRequesterMono instance) { + return new StateAssert<>(RequestResponseRequesterMono.STATE, instance); + } + + public static StateAssert assertThat( + RequestStreamRequesterFlux instance) { + return new StateAssert<>(RequestStreamRequesterFlux.STATE, instance); + } + + public static StateAssert assertThat( + RequestChannelRequesterFlux instance) { + return new StateAssert<>(RequestChannelRequesterFlux.STATE, instance); + } + + public static StateAssert assertThat( + RequestChannelResponderSubscriber instance) { + return new StateAssert<>(RequestChannelResponderSubscriber.STATE, instance); + } + + private final Failures failures = Failures.instance(); + private final T instance; + + public StateAssert(AtomicLongFieldUpdater updater, T instance) { + super(updater, StateAssert.class); + this.instance = instance; + } + + public StateAssert isUnsubscribed() { + long currentState = actual.get(instance); + if (isSubscribed(currentState) || StateUtils.isTerminated(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, UNSUBSCRIBED_STATE)); + } + return this; + } + + public StateAssert hasSubscribedFlagOnly() { + long currentState = actual.get(instance); + if (currentState != SUBSCRIBED_FLAG) { + throw failures.failure(info, shouldHaveFlag(currentState, SUBSCRIBED_FLAG)); + } + return this; + } + + public StateAssert hasSubscribedFlag() { + long currentState = actual.get(instance); + if (!isSubscribed(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, SUBSCRIBED_FLAG)); + } + return this; + } + + public StateAssert hasRequestN(long n) { + long currentState = actual.get(instance); + if (extractRequestN(currentState) != n) { + throw failures.failure(info, shouldHaveRequestN(currentState, n)); + } + return this; + } + + public StateAssert hasRequestNBetween(long min, long max) { + long currentState = actual.get(instance); + final long requestN = extractRequestN(currentState); + if (requestN < min || requestN > max) { + throw failures.failure(info, shouldHaveRequestNBetween(currentState, min, max)); + } + return this; + } + + public StateAssert hasFirstFrameSentFlag() { + long currentState = actual.get(instance); + if (!isFirstFrameSent(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, FIRST_FRAME_SENT_FLAG)); + } + return this; + } + + public StateAssert hasNoFirstFrameSentFlag() { + long currentState = actual.get(instance); + if (isFirstFrameSent(currentState)) { + throw failures.failure(info, shouldNotHaveFlag(currentState, FIRST_FRAME_SENT_FLAG)); + } + return this; + } + + public StateAssert hasReassemblingFlag() { + long currentState = actual.get(instance); + if (!isReassembling(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, REASSEMBLING_FLAG)); + } + return this; + } + + public StateAssert hasNoReassemblingFlag() { + long currentState = actual.get(instance); + if (isReassembling(currentState)) { + throw failures.failure(info, shouldNotHaveFlag(currentState, REASSEMBLING_FLAG)); + } + return this; + } + + public StateAssert hasInboundTerminated() { + long currentState = actual.get(instance); + if (!StateUtils.isInboundTerminated(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, INBOUND_TERMINATED_FLAG)); + } + return this; + } + + public StateAssert hasOutboundTerminated() { + long currentState = actual.get(instance); + if (!StateUtils.isOutboundTerminated(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, OUTBOUND_TERMINATED_FLAG)); + } + return this; + } + + public StateAssert isTerminated() { + long currentState = actual.get(instance); + if (!StateUtils.isTerminated(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, TERMINATED_STATE)); + } + return this; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java b/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java new file mode 100644 index 000000000..16bd9f16e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java @@ -0,0 +1,120 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.netty.util.collection.IntObjectHashMap; +import io.netty.util.collection.IntObjectMap; +import org.junit.jupiter.api.Test; + +public class StreamIdSupplierTest { + @Test + public void testClientSequence() { + IntObjectMap map = new IntObjectHashMap<>(); + StreamIdSupplier s = StreamIdSupplier.clientSupplier(); + assertThat(s.nextStreamId(map)).isEqualTo(1); + assertThat(s.nextStreamId(map)).isEqualTo(3); + assertThat(s.nextStreamId(map)).isEqualTo(5); + } + + @Test + public void testServerSequence() { + IntObjectMap map = new IntObjectHashMap<>(); + StreamIdSupplier s = StreamIdSupplier.serverSupplier(); + assertEquals(2, s.nextStreamId(map)); + assertEquals(4, s.nextStreamId(map)); + assertEquals(6, s.nextStreamId(map)); + } + + @Test + public void testClientIsValid() { + IntObjectMap map = new IntObjectHashMap<>(); + StreamIdSupplier s = StreamIdSupplier.clientSupplier(); + + assertFalse(s.isBeforeOrCurrent(1)); + assertFalse(s.isBeforeOrCurrent(3)); + + s.nextStreamId(map); + assertTrue(s.isBeforeOrCurrent(1)); + assertFalse(s.isBeforeOrCurrent(3)); + + s.nextStreamId(map); + assertTrue(s.isBeforeOrCurrent(3)); + + // negative + assertFalse(s.isBeforeOrCurrent(-1)); + // connection + assertFalse(s.isBeforeOrCurrent(0)); + // server also accepted (checked externally) + assertTrue(s.isBeforeOrCurrent(2)); + } + + @Test + public void testServerIsValid() { + IntObjectMap map = new IntObjectHashMap<>(); + StreamIdSupplier s = StreamIdSupplier.serverSupplier(); + + assertFalse(s.isBeforeOrCurrent(2)); + assertFalse(s.isBeforeOrCurrent(4)); + + s.nextStreamId(map); + assertTrue(s.isBeforeOrCurrent(2)); + assertFalse(s.isBeforeOrCurrent(4)); + + s.nextStreamId(map); + assertTrue(s.isBeforeOrCurrent(4)); + + // negative + assertFalse(s.isBeforeOrCurrent(-2)); + // connection + assertFalse(s.isBeforeOrCurrent(0)); + // client also accepted (checked externally) + assertTrue(s.isBeforeOrCurrent(1)); + } + + @Test + public void testWrap() { + IntObjectMap map = new IntObjectHashMap<>(); + StreamIdSupplier s = new StreamIdSupplier(Integer.MAX_VALUE - 3); + + assertEquals(2147483646, s.nextStreamId(map)); + assertEquals(2, s.nextStreamId(map)); + assertEquals(4, s.nextStreamId(map)); + + s = new StreamIdSupplier(Integer.MAX_VALUE - 2); + + assertEquals(2147483647, s.nextStreamId(map)); + assertEquals(1, s.nextStreamId(map)); + assertEquals(3, s.nextStreamId(map)); + } + + @Test + public void testSkipFound() { + IntObjectMap map = new IntObjectHashMap<>(); + map.put(5, new Object()); + map.put(9, new Object()); + StreamIdSupplier s = StreamIdSupplier.clientSupplier(); + assertEquals(1, s.nextStreamId(map)); + assertEquals(3, s.nextStreamId(map)); + assertEquals(7, s.nextStreamId(map)); + assertEquals(11, s.nextStreamId(map)); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java b/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java new file mode 100644 index 000000000..e282d72d5 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java @@ -0,0 +1,281 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.ByteBufPayload; +import java.util.ArrayList; +import java.util.concurrent.ThreadLocalRandom; +import org.assertj.core.api.Assertions; +import reactor.core.Exceptions; +import reactor.util.annotation.Nullable; + +final class TestRequesterResponderSupport extends RequesterResponderSupport implements RSocket { + + static final String DATA_CONTENT = "testData"; + static final String METADATA_CONTENT = "testMetadata"; + + final Throwable error; + + TestRequesterResponderSupport( + @Nullable Throwable error, + StreamIdSupplier streamIdSupplier, + DuplexConnection connection, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize, + @Nullable RequestInterceptor requestInterceptor) { + super( + mtu, + maxFrameLength, + maxInboundPayloadSize, + PayloadDecoder.ZERO_COPY, + connection, + streamIdSupplier, + (__) -> requestInterceptor); + this.error = error; + } + + @Override + public TestDuplexConnection getDuplexConnection() { + return (TestDuplexConnection) super.getDuplexConnection(); + } + + static Payload genericPayload(LeaksTrackingByteBufAllocator allocator) { + ByteBuf data = allocator.buffer(); + data.writeCharSequence(DATA_CONTENT, CharsetUtil.UTF_8); + + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence(METADATA_CONTENT, CharsetUtil.UTF_8); + + return ByteBufPayload.create(data, metadata); + } + + static Payload fixedSizePayload(LeaksTrackingByteBufAllocator allocator, int contentSize) { + final int dataSize = ThreadLocalRandom.current().nextInt(0, contentSize); + final byte[] dataBytes = new byte[dataSize]; + ThreadLocalRandom.current().nextBytes(dataBytes); + ByteBuf data = allocator.buffer(dataSize); + data.writeBytes(dataBytes); + + ByteBuf metadata; + int metadataSize = contentSize - dataSize; + if (metadataSize > 0) { + final byte[] metadataBytes = new byte[metadataSize]; + metadata = allocator.buffer(metadataSize); + metadata.writeBytes(metadataBytes); + } else { + metadata = ThreadLocalRandom.current().nextBoolean() ? Unpooled.EMPTY_BUFFER : null; + } + + return ByteBufPayload.create(data, metadata); + } + + static Payload randomPayload(LeaksTrackingByteBufAllocator allocator) { + boolean hasMetadata = ThreadLocalRandom.current().nextBoolean(); + ByteBuf metadataByteBuf; + if (hasMetadata) { + byte[] randomMetadata = new byte[ThreadLocalRandom.current().nextInt(0, 512)]; + ThreadLocalRandom.current().nextBytes(randomMetadata); + metadataByteBuf = allocator.buffer().writeBytes(randomMetadata); + } else { + metadataByteBuf = null; + } + byte[] randomData = new byte[ThreadLocalRandom.current().nextInt(512, 1024)]; + ThreadLocalRandom.current().nextBytes(randomData); + + ByteBuf dataByteBuf = allocator.buffer().writeBytes(randomData); + return ByteBufPayload.create(dataByteBuf, metadataByteBuf); + } + + static Payload randomMetadataOnlyPayload(LeaksTrackingByteBufAllocator allocator) { + byte[] randomMetadata = new byte[ThreadLocalRandom.current().nextInt(512, 1024)]; + ThreadLocalRandom.current().nextBytes(randomMetadata); + ByteBuf metadataByteBuf = allocator.buffer().writeBytes(randomMetadata); + + return ByteBufPayload.create(Unpooled.EMPTY_BUFFER, metadataByteBuf); + } + + static ArrayList prepareFragments( + LeaksTrackingByteBufAllocator allocator, int mtu, Payload payload) { + + return prepareFragments(allocator, mtu, payload, FrameType.NEXT_COMPLETE); + } + + static ArrayList prepareFragments( + LeaksTrackingByteBufAllocator allocator, int mtu, Payload payload, FrameType frameType) { + + boolean hasMetadata = payload.hasMetadata(); + ByteBuf data = payload.sliceData(); + ByteBuf metadata = payload.sliceMetadata(); + ArrayList fragments = new ArrayList<>(); + + fragments.add( + frameType.hasInitialRequestN() + ? FragmentationUtils.encodeFirstFragment( + allocator, mtu, 1L, frameType, 1, hasMetadata, metadata, data) + : FragmentationUtils.encodeFirstFragment( + allocator, mtu, frameType, 1, hasMetadata, metadata, data)); + + while (metadata.isReadable() || data.isReadable()) { + fragments.add( + FragmentationUtils.encodeFollowsFragment(allocator, mtu, 1, true, metadata, data)); + } + + return fragments; + } + + @Override + public synchronized int getNextStreamId() { + int nextStreamId = super.getNextStreamId(); + + if (error != null) { + throw Exceptions.propagate(error); + } + + return nextStreamId; + } + + @Override + public synchronized int addAndGetNextStreamId(FrameHandler frameHandler) { + int nextStreamId = super.addAndGetNextStreamId(frameHandler); + + if (error != null) { + super.remove(nextStreamId, frameHandler); + throw Exceptions.propagate(error); + } + + return nextStreamId; + } + + public static TestRequesterResponderSupport client( + @Nullable Throwable e, @Nullable RequestInterceptor requestInterceptor) { + return client( + new TestDuplexConnection( + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT)), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + requestInterceptor, + e); + } + + public static TestRequesterResponderSupport client(@Nullable Throwable e) { + return client(0, FRAME_LENGTH_MASK, Integer.MAX_VALUE, e); + } + + public static TestRequesterResponderSupport client( + int mtu, int maxFrameLength, int maxInboundPayloadSize, @Nullable Throwable e) { + return client( + new TestDuplexConnection( + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT)), + mtu, + maxFrameLength, + maxInboundPayloadSize, + null, + e); + } + + public static TestRequesterResponderSupport client( + TestDuplexConnection duplexConnection, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize) { + return client(duplexConnection, mtu, maxFrameLength, maxInboundPayloadSize, null); + } + + public static TestRequesterResponderSupport client( + TestDuplexConnection duplexConnection, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize, + @Nullable RequestInterceptor requestInterceptor) { + return client( + duplexConnection, mtu, maxFrameLength, maxInboundPayloadSize, requestInterceptor, null); + } + + public static TestRequesterResponderSupport client( + TestDuplexConnection duplexConnection, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize, + @Nullable RequestInterceptor requestInterceptor, + @Nullable Throwable e) { + return new TestRequesterResponderSupport( + e, + StreamIdSupplier.clientSupplier(), + duplexConnection, + mtu, + maxFrameLength, + maxInboundPayloadSize, + requestInterceptor); + } + + public static TestRequesterResponderSupport client( + int mtu, int maxFrameLength, int maxInboundPayloadSize) { + return client(mtu, maxFrameLength, maxInboundPayloadSize, null); + } + + public static TestRequesterResponderSupport client(int mtu, int maxFrameLength) { + return client(mtu, maxFrameLength, Integer.MAX_VALUE); + } + + public static TestRequesterResponderSupport client(int mtu) { + return client(mtu, FRAME_LENGTH_MASK); + } + + public static TestRequesterResponderSupport client() { + return client(0); + } + + public static TestRequesterResponderSupport client(RequestInterceptor requestInterceptor) { + return client( + new TestDuplexConnection( + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT)), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + requestInterceptor); + } + + public TestRequesterResponderSupport assertNoActiveStreams() { + Assertions.assertThat(activeStreams).isEmpty(); + return this; + } + + public TestRequesterResponderSupport assertHasStream(int i, FrameHandler stream) { + Assertions.assertThat(activeStreams).containsEntry(i, stream); + return this; + } + + @Override + public LeaksTrackingByteBufAllocator getAllocator() { + return (LeaksTrackingByteBufAllocator) super.getAllocator(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/ApplicationErrorExceptionTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/ApplicationErrorExceptionTest.java new file mode 100644 index 000000000..35b30b951 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/ApplicationErrorExceptionTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +final class ApplicationErrorExceptionTest + implements RSocketExceptionTest { + + @Override + public ApplicationErrorException getException(String message) { + return new ApplicationErrorException(message); + } + + @Override + public ApplicationErrorException getException(String message, Throwable cause) { + return new ApplicationErrorException(message, cause); + } + + @Override + public int getSpecifiedErrorCode() { + return 0x00000201; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidRequestException.java b/rsocket-core/src/test/java/io/rsocket/exceptions/CanceledExceptionTest.java similarity index 53% rename from rsocket-core/src/main/java/io/rsocket/exceptions/InvalidRequestException.java rename to rsocket-core/src/test/java/io/rsocket/exceptions/CanceledExceptionTest.java index 116ec1e94..6df9e6a4d 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidRequestException.java +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/CanceledExceptionTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,24 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.exceptions; - -import io.rsocket.frame.ErrorFrameFlyweight; -public class InvalidRequestException extends RSocketException { +package io.rsocket.exceptions; - private static final long serialVersionUID = 812240443606264942L; +final class CanceledExceptionTest implements RSocketExceptionTest { - public InvalidRequestException(String message) { - super(message); + @Override + public CanceledException getException(String message) { + return new CanceledException(message); } - public InvalidRequestException(String message, Throwable cause) { - super(message, cause); + @Override + public CanceledException getException(String message, Throwable cause) { + return new CanceledException(message, cause); } @Override - public int errorCode() { - return ErrorFrameFlyweight.INVALID; + public int getSpecifiedErrorCode() { + return 0x00000203; } } diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/ConnectionCloseExceptionTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/ConnectionCloseExceptionTest.java new file mode 100644 index 000000000..fe98b55de --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/ConnectionCloseExceptionTest.java @@ -0,0 +1,35 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +final class ConnectionCloseExceptionTest implements RSocketExceptionTest { + + @Override + public ConnectionCloseException getException(String message) { + return new ConnectionCloseException(message); + } + + @Override + public ConnectionCloseException getException(String message, Throwable cause) { + return new ConnectionCloseException(message, cause); + } + + @Override + public int getSpecifiedErrorCode() { + return 0x00000102; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/ConnectionErrorExceptionTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/ConnectionErrorExceptionTest.java new file mode 100644 index 000000000..a2bd45a38 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/ConnectionErrorExceptionTest.java @@ -0,0 +1,35 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +final class ConnectionErrorExceptionTest implements RSocketExceptionTest { + + @Override + public ConnectionErrorException getException(String message) { + return new ConnectionErrorException(message); + } + + @Override + public ConnectionErrorException getException(String message, Throwable cause) { + return new ConnectionErrorException(message, cause); + } + + @Override + public int getSpecifiedErrorCode() { + return 0x00000101; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java new file mode 100644 index 000000000..a316aed8b --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java @@ -0,0 +1,283 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import static io.rsocket.frame.ErrorFrameCodec.APPLICATION_ERROR; +import static io.rsocket.frame.ErrorFrameCodec.CANCELED; +import static io.rsocket.frame.ErrorFrameCodec.CONNECTION_CLOSE; +import static io.rsocket.frame.ErrorFrameCodec.CONNECTION_ERROR; +import static io.rsocket.frame.ErrorFrameCodec.INVALID; +import static io.rsocket.frame.ErrorFrameCodec.INVALID_SETUP; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED_RESUME; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED_SETUP; +import static io.rsocket.frame.ErrorFrameCodec.UNSUPPORTED_SETUP; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.rsocket.RaceTestConstants; +import io.rsocket.frame.ErrorFrameCodec; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +final class ExceptionsTest { + @DisplayName("from returns ApplicationErrorException") + @Test + void fromApplicationException() { + ByteBuf byteBuf = createErrorFrame(1, APPLICATION_ERROR, "test-message"); + + try { + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(ApplicationErrorException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid Error frame in Stream ID 0: 0x%08X '%s'", APPLICATION_ERROR, "test-message"); + } finally { + byteBuf.release(); + } + } + + @DisplayName("from returns CanceledException") + @Test + void fromCanceledException() { + ByteBuf byteBuf = createErrorFrame(1, CANCELED, "test-message"); + + try { + + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(CanceledException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", CANCELED, "test-message"); + } finally { + byteBuf.release(); + } + } + + @DisplayName("from returns ConnectionCloseException") + @Test + void fromConnectionCloseException() { + ByteBuf byteBuf = createErrorFrame(0, CONNECTION_CLOSE, "test-message"); + try { + + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(ConnectionCloseException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", CONNECTION_CLOSE, "test-message"); + } finally { + byteBuf.release(); + } + } + + @DisplayName("from returns ConnectionErrorException") + @Test + void fromConnectionErrorException() { + ByteBuf byteBuf = createErrorFrame(0, CONNECTION_ERROR, "test-message"); + + try { + + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(ConnectionErrorException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", CONNECTION_ERROR, "test-message"); + } finally { + byteBuf.release(); + } + } + + @DisplayName("from returns IllegalArgumentException if error frame has illegal error code") + @Test + void fromIllegalErrorFrame() { + ByteBuf byteBuf = createErrorFrame(0, 0x00000000, "test-message"); + try { + + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", 0, "test-message") + .isInstanceOf(IllegalArgumentException.class); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 1: 0x%08X '%s'", 0x00000000, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } + } + + @DisplayName("from returns InvalidException") + @Test + void fromInvalidException() { + ByteBuf byteBuf = createErrorFrame(1, INVALID, "test-message"); + try { + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(InvalidException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", INVALID, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } + } + + @DisplayName("from returns InvalidSetupException") + @Test + void fromInvalidSetupException() { + ByteBuf byteBuf = createErrorFrame(0, INVALID_SETUP, "test-message"); + try { + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(InvalidSetupException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", INVALID_SETUP, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } + } + + @DisplayName("from returns RejectedException") + @Test + void fromRejectedException() { + ByteBuf byteBuf = createErrorFrame(1, REJECTED, "test-message"); + try { + + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(RejectedException.class) + .withFailMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", REJECTED, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } + } + + @DisplayName("from returns RejectedResumeException") + @Test + void fromRejectedResumeException() { + ByteBuf byteBuf = createErrorFrame(0, REJECTED_RESUME, "test-message"); + try { + + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(RejectedResumeException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", REJECTED_RESUME, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } + } + + @DisplayName("from returns RejectedSetupException") + @Test + void fromRejectedSetupException() { + ByteBuf byteBuf = createErrorFrame(0, REJECTED_SETUP, "test-message"); + try { + + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(RejectedSetupException.class) + .withFailMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", REJECTED_SETUP, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } + } + + @DisplayName("from returns UnsupportedSetupException") + @Test + void fromUnsupportedSetupException() { + ByteBuf byteBuf = createErrorFrame(0, UNSUPPORTED_SETUP, "test-message"); + try { + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(UnsupportedSetupException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", UNSUPPORTED_SETUP, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } + } + + @DisplayName("from returns CustomRSocketException") + @Test + void fromCustomRSocketException() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + int randomCode = + ThreadLocalRandom.current().nextBoolean() + ? ThreadLocalRandom.current() + .nextInt(Integer.MIN_VALUE, ErrorFrameCodec.MAX_USER_ALLOWED_ERROR_CODE) + : ThreadLocalRandom.current() + .nextInt(ErrorFrameCodec.MIN_USER_ALLOWED_ERROR_CODE, Integer.MAX_VALUE); + ByteBuf byteBuf = createErrorFrame(0, randomCode, "test-message"); + try { + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(CustomRSocketException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 0: 0x%08X '%s'", randomCode, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } + } + } + + @DisplayName("from throws NullPointerException with null frame") + @Test + void fromWithNullFrame() { + assertThatNullPointerException() + .isThrownBy(() -> Exceptions.from(0, null)) + .withMessage("frame must not be null"); + } + + private ByteBuf createErrorFrame(int streamId, int errorCode, String message) { + return ErrorFrameCodec.encode( + UnpooledByteBufAllocator.DEFAULT, streamId, new TestRSocketException(errorCode, message)); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/CancelException.java b/rsocket-core/src/test/java/io/rsocket/exceptions/InvalidExceptionTest.java similarity index 53% rename from rsocket-core/src/main/java/io/rsocket/exceptions/CancelException.java rename to rsocket-core/src/test/java/io/rsocket/exceptions/InvalidExceptionTest.java index 05a18328a..a7dec62b4 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/CancelException.java +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/InvalidExceptionTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,24 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.exceptions; - -import io.rsocket.frame.ErrorFrameFlyweight; -public class CancelException extends RSocketException { +package io.rsocket.exceptions; - private static final long serialVersionUID = 3579712120019438212L; +final class InvalidExceptionTest implements RSocketExceptionTest { - public CancelException(String message) { - super(message); + @Override + public InvalidException getException(String message) { + return new InvalidException(message); } - public CancelException(String message, Throwable cause) { - super(message, cause); + @Override + public InvalidException getException(String message, Throwable cause) { + return new InvalidException(message, cause); } @Override - public int errorCode() { - return ErrorFrameFlyweight.CANCELED; + public int getSpecifiedErrorCode() { + return 0x00000204; } } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionException.java b/rsocket-core/src/test/java/io/rsocket/exceptions/InvalidSetupExceptionTest.java similarity index 52% rename from rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionException.java rename to rsocket-core/src/test/java/io/rsocket/exceptions/InvalidSetupExceptionTest.java index bb04f79fb..d7fce8cc8 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionException.java +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/InvalidSetupExceptionTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,24 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.exceptions; - -import io.rsocket.frame.ErrorFrameFlyweight; -public class ConnectionException extends RSocketException implements Retryable { +package io.rsocket.exceptions; - private static final long serialVersionUID = -6565180364631212778L; +final class InvalidSetupExceptionTest implements RSocketExceptionTest { - public ConnectionException(String message) { - super(message); + @Override + public InvalidSetupException getException(String message) { + return new InvalidSetupException(message); } - public ConnectionException(String message, Throwable cause) { - super(message, cause); + @Override + public InvalidSetupException getException(String message, Throwable cause) { + return new InvalidSetupException(message, cause); } @Override - public int errorCode() { - return ErrorFrameFlyweight.CONNECTION_ERROR; + public int getSpecifiedErrorCode() { + return 0x00000001; } } diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/RSocketExceptionTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/RSocketExceptionTest.java new file mode 100644 index 000000000..9aa8fc364 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/RSocketExceptionTest.java @@ -0,0 +1,50 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.rsocket.RSocketErrorException; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +interface RSocketExceptionTest { + + @DisplayName("constructor does not throw NullPointerException with null message") + @Test + default void constructorWithNullMessage() { + assertThat(getException(null)).hasMessage(null); + } + + @DisplayName("constructor does not throw NullPointerException with null message and cause") + @Test + default void constructorWithNullMessageAndCause() { + assertThat(getException(null)).hasMessage(null); + } + + @DisplayName("errorCode returns specified value") + @Test + default void errorCodeReturnsSpecifiedValue() { + assertThat(getException("test-message").errorCode()).isEqualTo(getSpecifiedErrorCode()); + } + + T getException(String message, Throwable cause); + + T getException(String message); + + int getSpecifiedErrorCode(); +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationException.java b/rsocket-core/src/test/java/io/rsocket/exceptions/RejectedExceptionTest.java similarity index 53% rename from rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationException.java rename to rsocket-core/src/test/java/io/rsocket/exceptions/RejectedExceptionTest.java index 8f8bf38c5..209595596 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationException.java +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/RejectedExceptionTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,24 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.exceptions; - -import io.rsocket.frame.ErrorFrameFlyweight; -public class ApplicationException extends RSocketException { +package io.rsocket.exceptions; - private static final long serialVersionUID = -8801579369150844447L; +final class RejectedExceptionTest implements RSocketExceptionTest { - public ApplicationException(String message) { - super(message); + @Override + public RejectedException getException(String message) { + return new RejectedException(message); } - public ApplicationException(String message, Throwable cause) { - super(message, cause); + @Override + public RejectedException getException(String message, Throwable cause) { + return new RejectedException(message, cause); } @Override - public int errorCode() { - return ErrorFrameFlyweight.APPLICATION_ERROR; + public int getSpecifiedErrorCode() { + return 0x00000202; } } diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/RejectedResumeExceptionTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/RejectedResumeExceptionTest.java new file mode 100644 index 000000000..555ff160d --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/RejectedResumeExceptionTest.java @@ -0,0 +1,35 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +final class RejectedResumeExceptionTest implements RSocketExceptionTest { + + @Override + public RejectedResumeException getException(String message) { + return new RejectedResumeException(message); + } + + @Override + public RejectedResumeException getException(String message, Throwable cause) { + return new RejectedResumeException(message, cause); + } + + @Override + public int getSpecifiedErrorCode() { + return 0x00000004; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/RejectedSetupExceptionTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/RejectedSetupExceptionTest.java new file mode 100644 index 000000000..2fe63c09d --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/RejectedSetupExceptionTest.java @@ -0,0 +1,35 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +final class RejectedSetupExceptionTest implements RSocketExceptionTest { + + @Override + public RejectedSetupException getException(String message) { + return new RejectedSetupException(message); + } + + @Override + public RejectedSetupException getException(String message, Throwable cause) { + return new RejectedSetupException(message, cause); + } + + @Override + public int getSpecifiedErrorCode() { + return 0x00000003; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/TestRSocketException.java b/rsocket-core/src/test/java/io/rsocket/exceptions/TestRSocketException.java new file mode 100644 index 000000000..15685aa43 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/TestRSocketException.java @@ -0,0 +1,42 @@ +package io.rsocket.exceptions; + +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; + +public class TestRSocketException extends RSocketErrorException { + private static final long serialVersionUID = 7873267740343446585L; + + private final int errorCode; + + /** + * Constructs a new exception with the specified message. + * + * @param errorCode customizable error code + * @param message the message + * @throws NullPointerException if {@code message} is {@code null} + * @throws IllegalArgumentException if {@code errorCode} is out of allowed range + */ + public TestRSocketException(int errorCode, String message) { + super(ErrorFrameCodec.APPLICATION_ERROR, message); + this.errorCode = errorCode; + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param errorCode customizable error code + * @param message the message + * @param cause the cause of this exception + * @throws NullPointerException if {@code message} or {@code cause} is {@code null} + * @throws IllegalArgumentException if {@code errorCode} is out of allowed range + */ + public TestRSocketException(int errorCode, String message, Throwable cause) { + super(ErrorFrameCodec.APPLICATION_ERROR, message, cause); + this.errorCode = errorCode; + } + + @Override + public int errorCode() { + return errorCode; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/UnsupportedSetupExceptionTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/UnsupportedSetupExceptionTest.java new file mode 100644 index 000000000..6c73ff564 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/UnsupportedSetupExceptionTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +final class UnsupportedSetupExceptionTest + implements RSocketExceptionTest { + + @Override + public UnsupportedSetupException getException(String message) { + return new UnsupportedSetupException(message); + } + + @Override + public UnsupportedSetupException getException(String message, Throwable cause) { + return new UnsupportedSetupException(message, cause); + } + + @Override + public int getSpecifiedErrorCode() { + return 0x00000002; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java deleted file mode 100644 index 2228851d7..000000000 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FragmentationDuplexConnectionTest.java +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.fragmentation; - -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import io.rsocket.DuplexConnection; -import io.rsocket.Frame; -import io.rsocket.FrameType; -import io.rsocket.util.PayloadImpl; -import java.nio.ByteBuffer; -import java.util.concurrent.ThreadLocalRandom; -import org.junit.Test; -import org.reactivestreams.Publisher; -import reactor.core.publisher.EmitterProcessor; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -/** */ -public class FragmentationDuplexConnectionTest { - @Test - public void testSendOneWithFragmentation() { - DuplexConnection mockConnection = mock(DuplexConnection.class); - when(mockConnection.send(any())) - .then( - invocation -> { - Publisher frames = invocation.getArgument(0); - - StepVerifier.create(frames).expectNextCount(16).verifyComplete(); - - return Mono.empty(); - }); - when(mockConnection.sendOne(any(Frame.class))).thenReturn(Mono.empty()); - - ByteBuffer data = createRandomBytes(16); - ByteBuffer metadata = createRandomBytes(16); - - Frame frame = - Frame.Request.from(1, FrameType.REQUEST_RESPONSE, new PayloadImpl(data, metadata), 1); - - FragmentationDuplexConnection duplexConnection = - new FragmentationDuplexConnection(mockConnection, 2); - - StepVerifier.create(duplexConnection.sendOne(frame)).verifyComplete(); - } - - @Test - public void testShouldNotFragment() { - DuplexConnection mockConnection = mock(DuplexConnection.class); - when(mockConnection.sendOne(any(Frame.class))).thenReturn(Mono.empty()); - - ByteBuffer data = createRandomBytes(16); - ByteBuffer metadata = createRandomBytes(16); - - Frame frame = Frame.Cancel.from(1); - - FragmentationDuplexConnection duplexConnection = - new FragmentationDuplexConnection(mockConnection, 2); - - StepVerifier.create(duplexConnection.sendOne(frame)).verifyComplete(); - - verify(mockConnection, times(1)).sendOne(frame); - } - - @Test - public void testShouldFragmentMultiple() { - DuplexConnection mockConnection = mock(DuplexConnection.class); - when(mockConnection.send(any())) - .then( - invocation -> { - Publisher frames = invocation.getArgument(0); - - StepVerifier.create(frames).expectNextCount(16).verifyComplete(); - - return Mono.empty(); - }); - when(mockConnection.sendOne(any(Frame.class))).thenReturn(Mono.empty()); - - ByteBuffer data = createRandomBytes(16); - ByteBuffer metadata = createRandomBytes(16); - - Frame frame1 = - Frame.Request.from(1, FrameType.REQUEST_RESPONSE, new PayloadImpl(data, metadata), 1); - Frame frame2 = - Frame.Request.from(2, FrameType.REQUEST_RESPONSE, new PayloadImpl(data, metadata), 1); - Frame frame3 = - Frame.Request.from(3, FrameType.REQUEST_RESPONSE, new PayloadImpl(data, metadata), 1); - - FragmentationDuplexConnection duplexConnection = - new FragmentationDuplexConnection(mockConnection, 2); - - StepVerifier.create(duplexConnection.send(Flux.just(frame1, frame2, frame3))).verifyComplete(); - - verify(mockConnection, times(3)).send(any()); - } - - @Test - public void testReassembleFragmentFrame() { - ByteBuffer data = createRandomBytes(16); - ByteBuffer metadata = createRandomBytes(16); - Frame frame = - Frame.Request.from(1024, FrameType.REQUEST_RESPONSE, new PayloadImpl(data, metadata), 1); - FrameFragmenter frameFragmenter = new FrameFragmenter(2); - Flux fragmentedFrames = frameFragmenter.fragment(frame); - EmitterProcessor processor = EmitterProcessor.create(128); - DuplexConnection mockConnection = mock(DuplexConnection.class); - when(mockConnection.receive()).then(answer -> processor); - - FragmentationDuplexConnection duplexConnection = - new FragmentationDuplexConnection(mockConnection, 2); - - fragmentedFrames.subscribe(processor); - - duplexConnection - .receive() - .log() - .doOnNext(c -> System.out.println("here - " + c.toString())) - .subscribe(); - } - - private ByteBuffer createRandomBytes(int size) { - byte[] bytes = new byte[size]; - ThreadLocalRandom.current().nextBytes(bytes); - return ByteBuffer.wrap(bytes); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java deleted file mode 100644 index e6119d294..000000000 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameFragmenterTest.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.fragmentation; - -import io.rsocket.Frame; -import io.rsocket.FrameType; -import io.rsocket.util.PayloadImpl; -import java.nio.ByteBuffer; -import java.util.concurrent.ThreadLocalRandom; -import org.junit.Test; -import reactor.test.StepVerifier; - -public class FrameFragmenterTest { - @Test - public void testFragmentWithMetadataAndData() { - ByteBuffer data = createRandomBytes(16); - ByteBuffer metadata = createRandomBytes(16); - - Frame from = - Frame.Request.from(1, FrameType.REQUEST_RESPONSE, new PayloadImpl(data, metadata), 1); - - FrameFragmenter frameFragmenter = new FrameFragmenter(2); - - StepVerifier.create(frameFragmenter.fragment(from)).expectNextCount(16).verifyComplete(); - } - - @Test - public void testFragmentWithMetadataAndDataWithOddData() { - ByteBuffer data = createRandomBytes(17); - ByteBuffer metadata = createRandomBytes(17); - - Frame from = - Frame.Request.from(1, FrameType.REQUEST_RESPONSE, new PayloadImpl(data, metadata), 1); - - FrameFragmenter frameFragmenter = new FrameFragmenter(2); - - StepVerifier.create(frameFragmenter.fragment(from)).expectNextCount(17).verifyComplete(); - } - - @Test - public void testFragmentWithMetadataOnly() { - ByteBuffer data = ByteBuffer.allocate(0); - ByteBuffer metadata = createRandomBytes(16); - - Frame from = - Frame.Request.from(1, FrameType.REQUEST_RESPONSE, new PayloadImpl(data, metadata), 1); - - FrameFragmenter frameFragmenter = new FrameFragmenter(2); - - StepVerifier.create(frameFragmenter.fragment(from)).expectNextCount(8).verifyComplete(); - } - - @Test - public void testFragmentWithDataOnly() { - ByteBuffer data = createRandomBytes(16); - ByteBuffer metadata = ByteBuffer.allocate(0); - - Frame from = - Frame.Request.from(1, FrameType.REQUEST_RESPONSE, new PayloadImpl(data, metadata), 1); - - FrameFragmenter frameFragmenter = new FrameFragmenter(2); - - StepVerifier.create(frameFragmenter.fragment(from)).expectNextCount(8).verifyComplete(); - } - - private ByteBuffer createRandomBytes(int size) { - byte[] bytes = new byte[size]; - ThreadLocalRandom.current().nextBytes(bytes); - return ByteBuffer.wrap(bytes); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java b/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java deleted file mode 100644 index 278c7fece..000000000 --- a/rsocket-core/src/test/java/io/rsocket/fragmentation/FrameReassemblerTest.java +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.fragmentation; - -import io.rsocket.Frame; -import io.rsocket.FrameType; -import io.rsocket.util.PayloadImpl; -import java.nio.ByteBuffer; -import java.util.concurrent.ThreadLocalRandom; -import org.junit.Test; - -/** */ -public class FrameReassemblerTest { - @Test - public void testAppend() { - ByteBuffer data = createRandomBytes(16); - ByteBuffer metadata = createRandomBytes(16); - - Frame from = - Frame.Request.from(1024, FrameType.REQUEST_RESPONSE, new PayloadImpl(data, metadata), 1); - FrameFragmenter frameFragmenter = new FrameFragmenter(2); - FrameReassembler reassembler = new FrameReassembler(from); - frameFragmenter.fragment(from).subscribe(reassembler::append); - } - - private ByteBuffer createRandomBytes(int size) { - byte[] bytes = new byte[size]; - ThreadLocalRandom.current().nextBytes(bytes); - return ByteBuffer.wrap(bytes); - } - /* - ByteBuffer data = createRandomBytes(16); - ByteBuffer metadata = createRandomBytes(16); - - Frame from = Frame.Request.from(1024, FrameType.REQUEST_RESPONSE, new PayloadImpl(data, metadata), 1); - - FrameFragmenter frameFragmenter = new FrameFragmenter(2); - - FrameReassembler reassembler = new FrameReassembler(2); - - frameFragmenter - .fragment(from) - .log() - .doOnNext(reassembler::append) - .blockLast(); - - Frame reassemble = reassembler.reassemble(); - - Assert.assertEquals(reassemble.getStreamId(), from.getStreamId()); - Assert.assertEquals(reassemble.getType(), from.getType()); - - ByteBuffer reassembleData = reassemble.getData(); - ByteBuffer reassembleMetadata = reassemble.getMetadata(); - - Assert.assertTrue(reassembleData.hasRemaining()); - Assert.assertTrue(reassembleMetadata.hasRemaining()); - - while (reassembleData.hasRemaining()) { - Assert.assertEquals(reassembleData.get(), data.get()); - } - - while (reassembleMetadata.hasRemaining()) { - Assert.assertEquals(reassembleMetadata.get(), metadata.get()); - } - } - - @Test - public void testReassmembleAndClear() { - ByteBuffer data = createRandomBytes(16); - ByteBuffer metadata = createRandomBytes(16); - - Frame request = Frame.Request.from(1024, FrameType.REQUEST_RESPONSE, new PayloadImpl(data, metadata), 1); - - FrameFragmenter frameFragmenter = new FrameFragmenter(2); - - FrameReassembler reassembler = new FrameReassembler(2); - - Iterable fragments = frameFragmenter - .fragment(request) - .log() - .map(frame -> frame.content().copy()) - .toIterable(); - - fragments - .forEach(f -> ByteBufUtil.prettyHexDump(f)); - - - for (int i = 0; i < 5; i++) { - for (ByteBuf frame : fragments) { - reassembler - .append(Frame.from(frame)); - } - - Frame reassemble = reassembler.reassemble(); - - Assert.assertEquals(reassemble.getStreamId(), request.getStreamId()); - Assert.assertEquals(reassemble.getType(), reassemble.getType()); - - ByteBuffer reassembleData = reassemble.getData(); - ByteBuffer reassembleMetadata = reassemble.getMetadata(); - - Assert.assertTrue(reassembleData.hasRemaining()); - Assert.assertTrue(reassembleMetadata.hasRemaining()); - - while (reassembleData.hasRemaining()) { - Assert.assertEquals(reassembleData.get(), data.get()); - } - - while (reassembleMetadata.hasRemaining()) { - Assert.assertEquals(reassembleMetadata.get(), metadata.get()); - } - - } - } - - @Test - public void substring() { - String s = "1234567890"; - String substring = s.substring(0, 5); - System.out.println(substring); - String substring1 = s.substring(5, 10); - System.out.println(substring1); - } - - */ -} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java b/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java new file mode 100644 index 000000000..b12d72b51 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java @@ -0,0 +1,57 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.util.IllegalReferenceCountException; +import org.assertj.core.api.Assertions; +import org.assertj.core.presentation.StandardRepresentation; +import org.junit.jupiter.api.extension.BeforeAllCallback; +import org.junit.jupiter.api.extension.ExtensionContext; + +public final class ByteBufRepresentation extends StandardRepresentation + implements BeforeAllCallback { + + @Override + public void beforeAll(ExtensionContext context) { + Assertions.useRepresentation(this); + } + + @Override + protected String fallbackToStringOf(Object object) { + if (object instanceof ByteBuf) { + try { + String normalBufferString = object.toString(); + ByteBuf byteBuf = (ByteBuf) object; + if (byteBuf.readableBytes() <= 128) { + String prettyHexDump = ByteBufUtil.prettyHexDump(byteBuf); + return new StringBuilder() + .append(normalBufferString) + .append("\n") + .append(prettyHexDump) + .toString(); + } else { + return normalBufferString; + } + } catch (IllegalReferenceCountException e) { + // noops + } + } + + return super.fallbackToStringOf(object); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ErrorFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ErrorFrameCodecTest.java new file mode 100644 index 000000000..dc04c1141 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/ErrorFrameCodecTest.java @@ -0,0 +1,21 @@ +package io.rsocket.frame; + +import static org.junit.jupiter.api.Assertions.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.rsocket.exceptions.ApplicationErrorException; +import org.junit.jupiter.api.Test; + +class ErrorFrameCodecTest { + @Test + void testEncode() { + ByteBuf frame = + ErrorFrameCodec.encode(ByteBufAllocator.DEFAULT, 1, new ApplicationErrorException("d")); + + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + assertEquals("00000b000000012c000000020164", ByteBufUtil.hexDump(frame)); + frame.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ErrorFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ErrorFrameFlyweightTest.java deleted file mode 100644 index 959938a8e..000000000 --- a/rsocket-core/src/test/java/io/rsocket/frame/ErrorFrameFlyweightTest.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.frame; - -import static io.rsocket.frame.ErrorFrameFlyweight.*; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import io.rsocket.Frame; -import io.rsocket.exceptions.*; -import java.nio.charset.StandardCharsets; -import org.junit.Test; - -public class ErrorFrameFlyweightTest { - private final ByteBuf byteBuf = Unpooled.buffer(1024); - - @Test - public void testEncoding() { - int encoded = - ErrorFrameFlyweight.encode( - byteBuf, - 1, - ErrorFrameFlyweight.APPLICATION_ERROR, - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - assertEquals("00000b000000012c000000020164", ByteBufUtil.hexDump(byteBuf, 0, encoded)); - - assertEquals(ErrorFrameFlyweight.APPLICATION_ERROR, ErrorFrameFlyweight.errorCode(byteBuf)); - assertEquals("d", ErrorFrameFlyweight.message(byteBuf)); - } - - @Test - public void testExceptions() throws Exception { - assertExceptionMapping(INVALID_SETUP, InvalidSetupException.class); - assertExceptionMapping(UNSUPPORTED_SETUP, UnsupportedSetupException.class); - assertExceptionMapping(REJECTED_SETUP, RejectedSetupException.class); - assertExceptionMapping(REJECTED_RESUME, RejectedResumeException.class); - assertExceptionMapping(CONNECTION_ERROR, ConnectionException.class); - assertExceptionMapping(CONNECTION_CLOSE, ConnectionCloseException.class); - assertExceptionMapping(APPLICATION_ERROR, ApplicationException.class); - assertExceptionMapping(REJECTED, RejectedException.class); - assertExceptionMapping(CANCELED, CancelException.class); - assertExceptionMapping(INVALID, InvalidRequestException.class); - } - - private void assertExceptionMapping(int errorCode, Class exceptionClass) - throws Exception { - T ex = exceptionClass.getConstructor(String.class).newInstance("error data"); - Frame f = Frame.Error.from(0, ex); - - assertEquals(errorCode, Frame.Error.errorCode(f)); - - RuntimeException ex2 = Exceptions.from(f); - - assertEquals(ex.getMessage(), ex2.getMessage()); - assertTrue(exceptionClass.isInstance(ex2)); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ExtensionFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ExtensionFrameCodecTest.java new file mode 100644 index 000000000..28209393e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/ExtensionFrameCodecTest.java @@ -0,0 +1,62 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import java.nio.charset.StandardCharsets; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class ExtensionFrameCodecTest { + + @Test + void extensionDataMetadata() { + ByteBuf metadata = bytebuf("md"); + ByteBuf data = bytebuf("d"); + int extendedType = 1; + + ByteBuf extension = + ExtensionFrameCodec.encode(ByteBufAllocator.DEFAULT, 1, extendedType, metadata, data); + + Assertions.assertTrue(FrameHeaderCodec.hasMetadata(extension)); + Assertions.assertEquals(extendedType, ExtensionFrameCodec.extendedType(extension)); + Assertions.assertEquals(metadata, ExtensionFrameCodec.metadata(extension)); + Assertions.assertEquals(data, ExtensionFrameCodec.data(extension)); + extension.release(); + } + + @Test + void extensionData() { + ByteBuf data = bytebuf("d"); + int extendedType = 1; + + ByteBuf extension = + ExtensionFrameCodec.encode(ByteBufAllocator.DEFAULT, 1, extendedType, null, data); + + Assertions.assertFalse(FrameHeaderCodec.hasMetadata(extension)); + Assertions.assertEquals(extendedType, ExtensionFrameCodec.extendedType(extension)); + Assertions.assertNull(ExtensionFrameCodec.metadata(extension)); + Assertions.assertEquals(data, ExtensionFrameCodec.data(extension)); + extension.release(); + } + + @Test + void extensionMetadata() { + ByteBuf metadata = bytebuf("md"); + int extendedType = 1; + + ByteBuf extension = + ExtensionFrameCodec.encode( + ByteBufAllocator.DEFAULT, 1, extendedType, metadata, Unpooled.EMPTY_BUFFER); + + Assertions.assertTrue(FrameHeaderCodec.hasMetadata(extension)); + Assertions.assertEquals(extendedType, ExtensionFrameCodec.extendedType(extension)); + Assertions.assertEquals(metadata, ExtensionFrameCodec.metadata(extension)); + Assertions.assertEquals(0, ExtensionFrameCodec.data(extension).readableBytes()); + extension.release(); + } + + private static ByteBuf bytebuf(String str) { + return Unpooled.copiedBuffer(str, StandardCharsets.UTF_8); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/FrameHeaderCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/FrameHeaderCodecTest.java new file mode 100644 index 000000000..15788e631 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/FrameHeaderCodecTest.java @@ -0,0 +1,36 @@ +package io.rsocket.frame; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import org.junit.jupiter.api.Test; + +class FrameHeaderCodecTest { + // Taken from spec + private static final int FRAME_MAX_SIZE = 16_777_215; + + @Test + void typeAndFlag() { + FrameType frameType = FrameType.REQUEST_FNF; + int flags = 0b1110110111; + ByteBuf header = FrameHeaderCodec.encode(ByteBufAllocator.DEFAULT, 0, frameType, flags); + + assertEquals(flags, FrameHeaderCodec.flags(header)); + assertEquals(frameType, FrameHeaderCodec.frameType(header)); + header.release(); + } + + @Test + void typeAndFlagTruncated() { + FrameType frameType = FrameType.SETUP; + int flags = 0b11110110111; // 1 bit too many + ByteBuf header = FrameHeaderCodec.encode(ByteBufAllocator.DEFAULT, 0, frameType, flags); + + assertNotEquals(flags, FrameHeaderCodec.flags(header)); + assertEquals(flags & 0b0000_0011_1111_1111, FrameHeaderCodec.flags(header)); + assertEquals(frameType, FrameHeaderCodec.frameType(header)); + header.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/FrameHeaderFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/FrameHeaderFlyweightTest.java deleted file mode 100644 index fca4c1095..000000000 --- a/rsocket-core/src/test/java/io/rsocket/frame/FrameHeaderFlyweightTest.java +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.frame; - -import static io.rsocket.frame.FrameHeaderFlyweight.FLAGS_M; -import static io.rsocket.frame.FrameHeaderFlyweight.FRAME_HEADER_LENGTH; -import static org.junit.Assert.*; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import io.rsocket.FrameType; -import org.junit.Test; - -public class FrameHeaderFlyweightTest { - // Taken from spec - private static final int FRAME_MAX_SIZE = 16_777_215; - - private final ByteBuf byteBuf = Unpooled.buffer(1024); - - @Test - public void headerSize() { - int frameLength = 123456; - FrameHeaderFlyweight.encodeFrameHeader(byteBuf, frameLength, 0, FrameType.SETUP, 0); - assertEquals(frameLength, FrameHeaderFlyweight.frameLength(byteBuf)); - } - - @Test - public void headerSizeMax() { - int frameLength = FRAME_MAX_SIZE; - FrameHeaderFlyweight.encodeFrameHeader(byteBuf, frameLength, 0, FrameType.SETUP, 0); - assertEquals(frameLength, FrameHeaderFlyweight.frameLength(byteBuf)); - } - - @Test(expected = IllegalArgumentException.class) - public void headerSizeTooLarge() { - FrameHeaderFlyweight.encodeFrameHeader(byteBuf, FRAME_MAX_SIZE + 1, 0, FrameType.SETUP, 0); - } - - @Test - public void frameLength() { - int length = - FrameHeaderFlyweight.encode( - byteBuf, 0, FLAGS_M, FrameType.SETUP, Unpooled.EMPTY_BUFFER, Unpooled.EMPTY_BUFFER); - assertEquals(length, 12); // 72 bits - } - - @Test - public void frameLengthNullMetadata() { - int length = - FrameHeaderFlyweight.encode(byteBuf, 0, 0, FrameType.SETUP, null, Unpooled.EMPTY_BUFFER); - assertEquals(length, 9); // 72 bits - } - - @Test - public void metadataLength() { - ByteBuf metadata = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4}); - FrameHeaderFlyweight.encode( - byteBuf, 0, FLAGS_M, FrameType.SETUP, metadata, Unpooled.EMPTY_BUFFER); - assertEquals( - 4, - FrameHeaderFlyweight.decodeMetadataLength(byteBuf, FrameHeaderFlyweight.FRAME_HEADER_LENGTH) - .longValue()); - } - - @Test - public void dataLength() { - ByteBuf data = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4, 5}); - int length = - FrameHeaderFlyweight.encode( - byteBuf, 0, FLAGS_M, FrameType.SETUP, Unpooled.EMPTY_BUFFER, data); - assertEquals( - 5, - FrameHeaderFlyweight.dataLength( - byteBuf, FrameType.SETUP, FrameHeaderFlyweight.FRAME_HEADER_LENGTH)); - } - - @Test - public void metadataSlice() { - ByteBuf metadata = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4}); - FrameHeaderFlyweight.encode( - byteBuf, 0, FLAGS_M, FrameType.REQUEST_RESPONSE, metadata, Unpooled.EMPTY_BUFFER); - metadata.resetReaderIndex(); - - assertEquals(metadata, FrameHeaderFlyweight.sliceFrameMetadata(byteBuf)); - } - - @Test - public void dataSlice() { - ByteBuf data = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4, 5}); - FrameHeaderFlyweight.encode( - byteBuf, 0, FLAGS_M, FrameType.REQUEST_RESPONSE, Unpooled.EMPTY_BUFFER, data); - data.resetReaderIndex(); - - assertEquals(data, FrameHeaderFlyweight.sliceFrameData(byteBuf)); - } - - @Test - public void streamId() { - int streamId = 1234; - FrameHeaderFlyweight.encode( - byteBuf, streamId, FLAGS_M, FrameType.SETUP, Unpooled.EMPTY_BUFFER, Unpooled.EMPTY_BUFFER); - assertEquals(streamId, FrameHeaderFlyweight.streamId(byteBuf)); - } - - @Test - public void typeAndFlag() { - FrameType frameType = FrameType.FIRE_AND_FORGET; - int flags = 0b1110110111; - FrameHeaderFlyweight.encode( - byteBuf, 0, flags, frameType, Unpooled.EMPTY_BUFFER, Unpooled.EMPTY_BUFFER); - - assertEquals(flags, FrameHeaderFlyweight.flags(byteBuf)); - assertEquals(frameType, FrameHeaderFlyweight.frameType(byteBuf)); - } - - @Test - public void typeAndFlagTruncated() { - FrameType frameType = FrameType.SETUP; - int flags = 0b11110110111; // 1 bit too many - FrameHeaderFlyweight.encode( - byteBuf, 0, flags, FrameType.SETUP, Unpooled.EMPTY_BUFFER, Unpooled.EMPTY_BUFFER); - - assertNotEquals(flags, FrameHeaderFlyweight.flags(byteBuf)); - assertEquals(flags & 0b0000_0011_1111_1111, FrameHeaderFlyweight.flags(byteBuf)); - assertEquals(frameType, FrameHeaderFlyweight.frameType(byteBuf)); - } - - @Test - public void missingMetadataLength() { - for (FrameType frameType : FrameType.values()) { - switch (frameType) { - case UNDEFINED: - break; - case CANCEL: - case METADATA_PUSH: - case LEASE: - assertFalse( - "!hasMetadataLengthField(): " + frameType, - FrameHeaderFlyweight.hasMetadataLengthField(frameType)); - break; - default: - if (frameType.canHaveMetadata()) { - assertTrue( - "hasMetadataLengthField(): " + frameType, - FrameHeaderFlyweight.hasMetadataLengthField(frameType)); - } - } - } - } - - @Test - public void wireFormat() { - ByteBuf expectedBuffer = Unpooled.buffer(1024); - int currentIndex = 0; - // frame length - int frameLength = - FrameHeaderFlyweight.FRAME_HEADER_LENGTH - FrameHeaderFlyweight.FRAME_LENGTH_SIZE; - expectedBuffer.setInt(currentIndex, frameLength << 8); - currentIndex += 3; - // stream id - expectedBuffer.setInt(currentIndex, 5); - currentIndex += Integer.BYTES; - // flags and frame type - expectedBuffer.setShort(currentIndex, (short) 0b001010_0001100000); - currentIndex += Short.BYTES; - - FrameType frameType = FrameType.NEXT_COMPLETE; - FrameHeaderFlyweight.encode(byteBuf, 5, 0, frameType, null, Unpooled.EMPTY_BUFFER); - - ByteBuf expected = expectedBuffer.slice(0, currentIndex); - ByteBuf actual = byteBuf.slice(0, FRAME_HEADER_LENGTH); - - assertEquals(ByteBufUtil.hexDump(expected), ByteBufUtil.hexDump(actual)); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/GenericFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/GenericFrameCodecTest.java new file mode 100644 index 000000000..ac19dc754 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/GenericFrameCodecTest.java @@ -0,0 +1,264 @@ +package io.rsocket.frame; + +import static org.junit.jupiter.api.Assertions.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import java.nio.charset.StandardCharsets; +import org.junit.jupiter.api.Test; + +class GenericFrameCodecTest { + @Test + void testEncoding() { + ByteBuf frame = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + 1, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + // Encoded FrameLength⌍ ⌌ Encoded Headers + // | | ⌌ Encoded Request(1) + // | | | ⌌Encoded Metadata Length + // | | | | ⌌Encoded Metadata + // | | | | | ⌌Encoded Data + // __|________|_________|______|____|___| + // ↓ ↓↓ ↓↓ ↓↓ ↓↓ ↓↓↓ + String expected = "000010000000011900000000010000026d6464"; + assertEquals(expected, ByteBufUtil.hexDump(frame)); + frame.release(); + } + + @Test + void testEncodingWithEmptyMetadata() { + ByteBuf frame = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + 1, + Unpooled.EMPTY_BUFFER, + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + // Encoded FrameLength⌍ ⌌ Encoded Headers + // | | ⌌ Encoded Request(1) + // | | | ⌌Encoded Metadata Length (0) + // | | | | ⌌Encoded Data + // __|________|_________|_______|___| + // ↓ ↓↓ ↓↓ ↓↓ ↓↓↓ + String expected = "00000e0000000119000000000100000064"; + assertEquals(expected, ByteBufUtil.hexDump(frame)); + frame.release(); + } + + @Test + void testEncodingWithNullMetadata() { + ByteBuf frame = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + 1, + null, + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + + // Encoded FrameLength⌍ ⌌ Encoded Headers + // | | ⌌ Encoded Request(1) + // | | | ⌌Encoded Data + // __|________|_________|_____| + // ↓<-> ↓↓ <-> ↓↓ <-> ↓↓↓ + String expected = "00000b0000000118000000000164"; + assertEquals(expected, ByteBufUtil.hexDump(frame)); + frame.release(); + } + + @Test + void requestResponseDataMetadata() { + ByteBuf request = + RequestResponseFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + String data = RequestResponseFrameCodec.data(request).toString(StandardCharsets.UTF_8); + String metadata = RequestResponseFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertEquals("d", data); + assertEquals("md", metadata); + request.release(); + } + + @Test + void requestResponseData() { + ByteBuf request = + RequestResponseFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + null, + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + String data = RequestResponseFrameCodec.data(request).toString(StandardCharsets.UTF_8); + ByteBuf metadata = RequestResponseFrameCodec.metadata(request); + + assertFalse(FrameHeaderCodec.hasMetadata(request)); + assertEquals("d", data); + assertNull(metadata); + request.release(); + } + + @Test + void requestResponseMetadata() { + ByteBuf request = + RequestResponseFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.EMPTY_BUFFER); + + ByteBuf data = RequestResponseFrameCodec.data(request); + String metadata = RequestResponseFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertTrue(data.readableBytes() == 0); + assertEquals("md", metadata); + request.release(); + } + + @Test + void requestStreamDataMetadata() { + ByteBuf request = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Integer.MAX_VALUE + 1L, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + long actualRequest = RequestStreamFrameCodec.initialRequestN(request); + String data = RequestStreamFrameCodec.data(request).toString(StandardCharsets.UTF_8); + String metadata = RequestStreamFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertEquals(Long.MAX_VALUE, actualRequest); + assertEquals("md", metadata); + assertEquals("d", data); + request.release(); + } + + @Test + void requestStreamData() { + ByteBuf request = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + 42, + null, + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + long actualRequest = RequestStreamFrameCodec.initialRequestN(request); + String data = RequestStreamFrameCodec.data(request).toString(StandardCharsets.UTF_8); + ByteBuf metadata = RequestStreamFrameCodec.metadata(request); + + assertFalse(FrameHeaderCodec.hasMetadata(request)); + assertEquals(42L, actualRequest); + assertNull(metadata); + assertEquals("d", data); + request.release(); + } + + @Test + void requestStreamMetadata() { + ByteBuf request = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + 42, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.EMPTY_BUFFER); + + long actualRequest = RequestStreamFrameCodec.initialRequestN(request); + ByteBuf data = RequestStreamFrameCodec.data(request); + String metadata = RequestStreamFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertEquals(42L, actualRequest); + assertTrue(data.readableBytes() == 0); + assertEquals("md", metadata); + request.release(); + } + + @Test + void requestFnfDataAndMetadata() { + ByteBuf request = + RequestFireAndForgetFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + String data = RequestFireAndForgetFrameCodec.data(request).toString(StandardCharsets.UTF_8); + String metadata = + RequestFireAndForgetFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertEquals("d", data); + assertEquals("md", metadata); + request.release(); + } + + @Test + void requestFnfData() { + ByteBuf request = + RequestFireAndForgetFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + null, + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + String data = RequestFireAndForgetFrameCodec.data(request).toString(StandardCharsets.UTF_8); + ByteBuf metadata = RequestFireAndForgetFrameCodec.metadata(request); + + assertFalse(FrameHeaderCodec.hasMetadata(request)); + assertEquals("d", data); + assertNull(metadata); + request.release(); + } + + @Test + void requestFnfMetadata() { + ByteBuf request = + RequestFireAndForgetFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.EMPTY_BUFFER); + + ByteBuf data = RequestFireAndForgetFrameCodec.data(request); + String metadata = + RequestFireAndForgetFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertEquals("md", metadata); + assertTrue(data.readableBytes() == 0); + request.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/KeepaliveFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/KeepaliveFrameFlyweightTest.java index eeb950148..bc013e024 100644 --- a/rsocket-core/src/test/java/io/rsocket/frame/KeepaliveFrameFlyweightTest.java +++ b/rsocket-core/src/test/java/io/rsocket/frame/KeepaliveFrameFlyweightTest.java @@ -1,52 +1,32 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - package io.rsocket.frame; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufUtil; import io.netty.buffer.Unpooled; import java.nio.charset.StandardCharsets; -import org.junit.Test; - -public class KeepaliveFrameFlyweightTest { - private final ByteBuf byteBuf = Unpooled.buffer(1024); +import org.junit.jupiter.api.Test; +class KeepaliveFrameFlyweightTest { @Test - public void canReadData() { + void canReadData() { ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); - int length = - KeepaliveFrameFlyweight.encode(byteBuf, KeepaliveFrameFlyweight.FLAGS_KEEPALIVE_R, data); - data.resetReaderIndex(); - - assertEquals( - KeepaliveFrameFlyweight.FLAGS_KEEPALIVE_R, - FrameHeaderFlyweight.flags(byteBuf) & KeepaliveFrameFlyweight.FLAGS_KEEPALIVE_R); - assertEquals(data, FrameHeaderFlyweight.sliceFrameData(byteBuf)); + ByteBuf frame = KeepAliveFrameCodec.encode(ByteBufAllocator.DEFAULT, true, 0, data); + assertTrue(KeepAliveFrameCodec.respondFlag(frame)); + assertEquals(data, KeepAliveFrameCodec.data(frame)); + frame.release(); } @Test - public void testEncoding() { - int encoded = - KeepaliveFrameFlyweight.encode( - byteBuf, - KeepaliveFrameFlyweight.FLAGS_KEEPALIVE_R, - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - assertEquals("00000f000000000c80000000000000000064", ByteBufUtil.hexDump(byteBuf, 0, encoded)); + void testEncoding() { + ByteBuf frame = + KeepAliveFrameCodec.encode( + ByteBufAllocator.DEFAULT, true, 0, Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + assertEquals("00000f000000000c80000000000000000064", ByteBufUtil.hexDump(frame)); + frame.release(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/frame/LeaseFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/LeaseFrameCodecTest.java new file mode 100644 index 000000000..73c3bde5e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/LeaseFrameCodecTest.java @@ -0,0 +1,42 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import java.nio.charset.StandardCharsets; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class LeaseFrameCodecTest { + + @Test + void leaseMetadata() { + ByteBuf metadata = bytebuf("md"); + int ttl = 1; + int numRequests = 42; + ByteBuf lease = LeaseFrameCodec.encode(ByteBufAllocator.DEFAULT, ttl, numRequests, metadata); + + Assertions.assertTrue(FrameHeaderCodec.hasMetadata(lease)); + Assertions.assertEquals(ttl, LeaseFrameCodec.ttl(lease)); + Assertions.assertEquals(numRequests, LeaseFrameCodec.numRequests(lease)); + Assertions.assertEquals(metadata, LeaseFrameCodec.metadata(lease)); + lease.release(); + } + + @Test + void leaseAbsentMetadata() { + int ttl = 1; + int numRequests = 42; + ByteBuf lease = LeaseFrameCodec.encode(ByteBufAllocator.DEFAULT, ttl, numRequests, null); + + Assertions.assertFalse(FrameHeaderCodec.hasMetadata(lease)); + Assertions.assertEquals(ttl, LeaseFrameCodec.ttl(lease)); + Assertions.assertEquals(numRequests, LeaseFrameCodec.numRequests(lease)); + Assertions.assertNull(LeaseFrameCodec.metadata(lease)); + lease.release(); + } + + private static ByteBuf bytebuf(String str) { + return Unpooled.copiedBuffer(str, StandardCharsets.UTF_8); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/LeaseFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/LeaseFrameFlyweightTest.java deleted file mode 100644 index b943c6f20..000000000 --- a/rsocket-core/src/test/java/io/rsocket/frame/LeaseFrameFlyweightTest.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.frame; - -import static org.junit.Assert.*; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import java.nio.charset.StandardCharsets; -import org.junit.Test; - -public class LeaseFrameFlyweightTest { - private final ByteBuf byteBuf = Unpooled.buffer(1024); - - @Test - public void size() { - ByteBuf metadata = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4}); - int length = LeaseFrameFlyweight.encode(byteBuf, 0, 0, metadata); - assertEquals(length, 9 + 4 * 2 + 4); // Frame header + ttl + #requests + 4 byte metadata - } - - @Test - public void testEncoding() { - int encoded = - LeaseFrameFlyweight.encode( - byteBuf, 0, 0, Unpooled.copiedBuffer("md", StandardCharsets.UTF_8)); - assertEquals( - "00001000000000090000000000000000006d64", ByteBufUtil.hexDump(byteBuf, 0, encoded)); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/PayloadFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/PayloadFlyweightTest.java new file mode 100644 index 000000000..aecbb31ce --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/PayloadFlyweightTest.java @@ -0,0 +1,88 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; +import io.rsocket.util.DefaultPayload; +import java.nio.charset.StandardCharsets; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class PayloadFlyweightTest { + + @Test + void nextCompleteDataMetadata() { + Payload payload = DefaultPayload.create("d", "md"); + ByteBuf nextComplete = + PayloadFrameCodec.encodeNextCompleteReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + String data = PayloadFrameCodec.data(nextComplete).toString(StandardCharsets.UTF_8); + String metadata = PayloadFrameCodec.metadata(nextComplete).toString(StandardCharsets.UTF_8); + Assertions.assertEquals("d", data); + Assertions.assertEquals("md", metadata); + nextComplete.release(); + } + + @Test + void nextCompleteData() { + Payload payload = DefaultPayload.create("d"); + ByteBuf nextComplete = + PayloadFrameCodec.encodeNextCompleteReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + String data = PayloadFrameCodec.data(nextComplete).toString(StandardCharsets.UTF_8); + ByteBuf metadata = PayloadFrameCodec.metadata(nextComplete); + Assertions.assertEquals("d", data); + Assertions.assertNull(metadata); + nextComplete.release(); + } + + @Test + void nextCompleteMetaData() { + Payload payload = + DefaultPayload.create( + Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer("md".getBytes(StandardCharsets.UTF_8))); + + ByteBuf nextComplete = + PayloadFrameCodec.encodeNextCompleteReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + ByteBuf data = PayloadFrameCodec.data(nextComplete); + String metadata = PayloadFrameCodec.metadata(nextComplete).toString(StandardCharsets.UTF_8); + Assertions.assertTrue(data.readableBytes() == 0); + Assertions.assertEquals("md", metadata); + nextComplete.release(); + } + + @Test + void nextDataMetadata() { + Payload payload = DefaultPayload.create("d", "md"); + ByteBuf next = + PayloadFrameCodec.encodeNextReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + String data = PayloadFrameCodec.data(next).toString(StandardCharsets.UTF_8); + String metadata = PayloadFrameCodec.metadata(next).toString(StandardCharsets.UTF_8); + Assertions.assertEquals("d", data); + Assertions.assertEquals("md", metadata); + next.release(); + } + + @Test + void nextData() { + Payload payload = DefaultPayload.create("d"); + ByteBuf next = + PayloadFrameCodec.encodeNextReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + String data = PayloadFrameCodec.data(next).toString(StandardCharsets.UTF_8); + ByteBuf metadata = PayloadFrameCodec.metadata(next); + Assertions.assertEquals("d", data); + Assertions.assertNull(metadata); + next.release(); + } + + @Test + void nextDataEmptyMetadata() { + Payload payload = DefaultPayload.create("d".getBytes(), new byte[0]); + ByteBuf next = + PayloadFrameCodec.encodeNextReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + String data = PayloadFrameCodec.data(next).toString(StandardCharsets.UTF_8); + ByteBuf metadata = PayloadFrameCodec.metadata(next); + Assertions.assertEquals("d", data); + Assertions.assertEquals(metadata.readableBytes(), 0); + next.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/RequestFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/RequestFrameFlyweightTest.java deleted file mode 100644 index 828a75b79..000000000 --- a/rsocket-core/src/test/java/io/rsocket/frame/RequestFrameFlyweightTest.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.frame; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import io.rsocket.Frame; -import io.rsocket.FrameType; -import io.rsocket.util.PayloadImpl; -import java.nio.charset.StandardCharsets; -import org.junit.Test; - -public class RequestFrameFlyweightTest { - private final ByteBuf byteBuf = Unpooled.buffer(1024); - - @Test - public void testEncoding() { - int encoded = - RequestFrameFlyweight.encode( - byteBuf, - 1, - FrameHeaderFlyweight.FLAGS_M, - FrameType.REQUEST_STREAM, - 1, - Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - assertEquals( - "000010000000011900000000010000026d6464", ByteBufUtil.hexDump(byteBuf, 0, encoded)); - - PayloadImpl payload = - new PayloadImpl(Frame.from(stringToBuf("000010000000011900000000010000026d6464"))); - - assertEquals("md", StandardCharsets.UTF_8.decode(payload.getMetadata()).toString()); - } - - @Test - public void testEncodingWithEmptyMetadata() { - int encoded = - RequestFrameFlyweight.encode( - byteBuf, - 1, - FrameHeaderFlyweight.FLAGS_M, - FrameType.REQUEST_STREAM, - 1, - Unpooled.copiedBuffer("", StandardCharsets.UTF_8), - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - assertEquals("00000e0000000119000000000100000064", ByteBufUtil.hexDump(byteBuf, 0, encoded)); - - PayloadImpl payload = - new PayloadImpl(Frame.from(stringToBuf("00000e0000000119000000000100000064"))); - - assertEquals("", StandardCharsets.UTF_8.decode(payload.getMetadata()).toString()); - } - - @Test - public void testEncodingWithNullMetadata() { - int encoded = - RequestFrameFlyweight.encode( - byteBuf, - 1, - 0, - FrameType.REQUEST_STREAM, - 1, - null, - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - assertEquals("00000b0000000118000000000164", ByteBufUtil.hexDump(byteBuf, 0, encoded)); - - PayloadImpl payload = new PayloadImpl(Frame.from(stringToBuf("00000b0000000118000000000164"))); - - assertFalse(payload.hasMetadata()); - } - - private String bufToString(int encoded) { - return ByteBufUtil.hexDump(byteBuf, 0, encoded); - } - - private ByteBuf stringToBuf(CharSequence s) { - return Unpooled.wrappedBuffer(ByteBufUtil.decodeHexDump(s)); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/RequestNFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/RequestNFrameCodecTest.java new file mode 100644 index 000000000..e38258040 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/RequestNFrameCodecTest.java @@ -0,0 +1,19 @@ +package io.rsocket.frame; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import org.junit.jupiter.api.Test; + +class RequestNFrameCodecTest { + @Test + void testEncoding() { + ByteBuf frame = RequestNFrameCodec.encode(ByteBufAllocator.DEFAULT, 1, 5); + + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + assertEquals("00000a00000001200000000005", ByteBufUtil.hexDump(frame)); + frame.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/RequestNFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/RequestNFrameFlyweightTest.java deleted file mode 100644 index 8df28a7be..000000000 --- a/rsocket-core/src/test/java/io/rsocket/frame/RequestNFrameFlyweightTest.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.frame; - -import static org.junit.Assert.assertEquals; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import org.junit.Test; - -public class RequestNFrameFlyweightTest { - private final ByteBuf byteBuf = Unpooled.buffer(1024); - - @Test - public void testEncoding() { - int encoded = RequestNFrameFlyweight.encode(byteBuf, 1, 5); - assertEquals("00000a00000001200000000005", ByteBufUtil.hexDump(byteBuf, 0, encoded)); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameCodecTest.java new file mode 100644 index 000000000..4815bfb8e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameCodecTest.java @@ -0,0 +1,41 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.frame; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import java.util.Arrays; +import org.junit.jupiter.api.Test; + +public class ResumeFrameCodecTest { + + @Test + void testEncoding() { + byte[] tokenBytes = new byte[65000]; + Arrays.fill(tokenBytes, (byte) 1); + ByteBuf token = Unpooled.wrappedBuffer(tokenBytes); + ByteBuf byteBuf = ResumeFrameCodec.encode(ByteBufAllocator.DEFAULT, token, 21, 12); + assertThat(ResumeFrameCodec.version(byteBuf)).isEqualTo(ResumeFrameCodec.CURRENT_VERSION); + assertThat(ResumeFrameCodec.token(byteBuf)).isEqualTo(token); + assertThat(ResumeFrameCodec.lastReceivedServerPos(byteBuf)).isEqualTo(21); + assertThat(ResumeFrameCodec.firstAvailableClientPos(byteBuf)).isEqualTo(12); + byteBuf.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameCodecTest.java new file mode 100644 index 000000000..b818d579d --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameCodecTest.java @@ -0,0 +1,17 @@ +package io.rsocket.frame; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import org.junit.jupiter.api.Test; + +public class ResumeOkFrameCodecTest { + + @Test + public void testEncoding() { + ByteBuf byteBuf = ResumeOkFrameCodec.encode(ByteBufAllocator.DEFAULT, 42); + assertThat(ResumeOkFrameCodec.lastReceivedClientPos(byteBuf)).isEqualTo(42); + byteBuf.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameCodecTest.java new file mode 100644 index 000000000..3317b4618 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameCodecTest.java @@ -0,0 +1,57 @@ +package io.rsocket.frame; + +import static org.junit.jupiter.api.Assertions.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; +import io.rsocket.util.DefaultPayload; +import java.util.Arrays; +import org.junit.jupiter.api.Test; + +class SetupFrameCodecTest { + @Test + void testEncodingNoResume() { + ByteBuf metadata = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4}); + ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); + Payload payload = DefaultPayload.create(data, metadata); + ByteBuf frame = + SetupFrameCodec.encode( + ByteBufAllocator.DEFAULT, false, 5, 500, "metadata_type", "data_type", payload); + + assertEquals(FrameType.SETUP, FrameHeaderCodec.frameType(frame)); + assertFalse(SetupFrameCodec.resumeEnabled(frame)); + assertEquals(0, SetupFrameCodec.resumeToken(frame).readableBytes()); + assertEquals("metadata_type", SetupFrameCodec.metadataMimeType(frame)); + assertEquals("data_type", SetupFrameCodec.dataMimeType(frame)); + assertEquals(payload.metadata(), SetupFrameCodec.metadata(frame)); + assertEquals(payload.data(), SetupFrameCodec.data(frame)); + assertEquals(SetupFrameCodec.CURRENT_VERSION, SetupFrameCodec.version(frame)); + frame.release(); + } + + @Test + void testEncodingResume() { + byte[] tokenBytes = new byte[65000]; + Arrays.fill(tokenBytes, (byte) 1); + ByteBuf metadata = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4}); + ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); + Payload payload = DefaultPayload.create(data, metadata); + ByteBuf token = Unpooled.wrappedBuffer(tokenBytes); + ByteBuf frame = + SetupFrameCodec.encode( + ByteBufAllocator.DEFAULT, true, 5, 500, token, "metadata_type", "data_type", payload); + + assertEquals(FrameType.SETUP, FrameHeaderCodec.frameType(frame)); + assertTrue(SetupFrameCodec.honorLease(frame)); + assertTrue(SetupFrameCodec.resumeEnabled(frame)); + assertEquals(token, SetupFrameCodec.resumeToken(frame)); + assertEquals("metadata_type", SetupFrameCodec.metadataMimeType(frame)); + assertEquals("data_type", SetupFrameCodec.dataMimeType(frame)); + assertEquals(payload.metadata(), SetupFrameCodec.metadata(frame)); + assertEquals(payload.data(), SetupFrameCodec.data(frame)); + assertEquals(SetupFrameCodec.CURRENT_VERSION, SetupFrameCodec.version(frame)); + frame.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameFlyweightTest.java deleted file mode 100644 index 8dd5f7a8b..000000000 --- a/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameFlyweightTest.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.frame; - -import static org.junit.Assert.*; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; -import io.rsocket.FrameType; -import java.nio.charset.StandardCharsets; -import org.junit.Test; - -public class SetupFrameFlyweightTest { - private final ByteBuf byteBuf = Unpooled.buffer(1024); - - @Test - public void validFrame() { - ByteBuf metadata = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4}); - ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); - SetupFrameFlyweight.encode(byteBuf, 0, 5, 500, "metadata_type", "data_type", metadata, data); - - metadata.resetReaderIndex(); - data.resetReaderIndex(); - - assertEquals(FrameType.SETUP, FrameHeaderFlyweight.frameType(byteBuf)); - assertEquals("metadata_type", SetupFrameFlyweight.metadataMimeType(byteBuf)); - assertEquals("data_type", SetupFrameFlyweight.dataMimeType(byteBuf)); - assertEquals(metadata, FrameHeaderFlyweight.sliceFrameMetadata(byteBuf)); - assertEquals(data, FrameHeaderFlyweight.sliceFrameData(byteBuf)); - } - - @Test(expected = IllegalArgumentException.class) - public void resumeNotSupported() { - SetupFrameFlyweight.encode( - byteBuf, - SetupFrameFlyweight.FLAGS_RESUME_ENABLE, - 5, - 500, - "", - "", - Unpooled.EMPTY_BUFFER, - Unpooled.EMPTY_BUFFER); - } - - @Test - public void validResumeFrame() { - ByteBuf token = Unpooled.wrappedBuffer(new byte[] {2, 3}); - ByteBuf metadata = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4}); - ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); - SetupFrameFlyweight.encode( - byteBuf, - SetupFrameFlyweight.FLAGS_RESUME_ENABLE, - 5, - 500, - token, - "metadata_type", - "data_type", - metadata, - data); - - token.resetReaderIndex(); - metadata.resetReaderIndex(); - data.resetReaderIndex(); - - assertEquals(FrameType.SETUP, FrameHeaderFlyweight.frameType(byteBuf)); - assertEquals("metadata_type", SetupFrameFlyweight.metadataMimeType(byteBuf)); - assertEquals("data_type", SetupFrameFlyweight.dataMimeType(byteBuf)); - assertEquals(metadata, FrameHeaderFlyweight.sliceFrameMetadata(byteBuf)); - assertEquals(data, FrameHeaderFlyweight.sliceFrameData(byteBuf)); - assertEquals( - SetupFrameFlyweight.FLAGS_RESUME_ENABLE, - FrameHeaderFlyweight.flags(byteBuf) & SetupFrameFlyweight.FLAGS_RESUME_ENABLE); - } - - @Test - public void testEncoding() { - int encoded = - SetupFrameFlyweight.encode( - byteBuf, - 0, - 5000, - 60000, - "mdmt", - "dmt", - Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), - Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); - assertEquals( - "00002100000000050000010000000013880000ea60046d646d7403646d740000026d6464", - ByteBufUtil.hexDump(byteBuf, 0, encoded)); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/VersionCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/VersionCodecTest.java new file mode 100644 index 000000000..be7fb837b --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/VersionCodecTest.java @@ -0,0 +1,48 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.frame; + +import static org.junit.jupiter.api.Assertions.*; + +import org.junit.jupiter.api.Test; + +public class VersionCodecTest { + @Test + public void simple() { + int version = VersionCodec.encode(1, 0); + assertEquals(1, VersionCodec.major(version)); + assertEquals(0, VersionCodec.minor(version)); + assertEquals(0x00010000, version); + assertEquals("1.0", VersionCodec.toString(version)); + } + + @Test + public void complex() { + int version = VersionCodec.encode(0x1234, 0x5678); + assertEquals(0x1234, VersionCodec.major(version)); + assertEquals(0x5678, VersionCodec.minor(version)); + assertEquals(0x12345678, version); + assertEquals("4660.22136", VersionCodec.toString(version)); + } + + @Test + public void noShortOverflow() { + int version = VersionCodec.encode(43210, 43211); + assertEquals(43210, VersionCodec.major(version)); + assertEquals(43211, VersionCodec.minor(version)); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/VersionFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/VersionFlyweightTest.java deleted file mode 100644 index 93a341735..000000000 --- a/rsocket-core/src/test/java/io/rsocket/frame/VersionFlyweightTest.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.frame; - -import static org.junit.Assert.*; - -import org.junit.Test; - -public class VersionFlyweightTest { - @Test - public void simple() { - int version = VersionFlyweight.encode(1, 0); - assertEquals(1, VersionFlyweight.major(version)); - assertEquals(0, VersionFlyweight.minor(version)); - assertEquals(0x00010000, version); - assertEquals("1.0", VersionFlyweight.toString(version)); - } - - @Test - public void complex() { - int version = VersionFlyweight.encode(0x1234, 0x5678); - assertEquals(0x1234, VersionFlyweight.major(version)); - assertEquals(0x5678, VersionFlyweight.minor(version)); - assertEquals(0x12345678, version); - assertEquals("4660.22136", VersionFlyweight.toString(version)); - } - - @Test - public void noShortOverflow() { - int version = VersionFlyweight.encode(43210, 43211); - assertEquals(43210, VersionFlyweight.major(version)); - assertEquals(43211, VersionFlyweight.minor(version)); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/old/LeaseFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/old/LeaseFrameFlyweightTest.java new file mode 100644 index 000000000..ef4fcc6b0 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/old/LeaseFrameFlyweightTest.java @@ -0,0 +1,37 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.frame.old; + +public class LeaseFrameFlyweightTest { + /*private final ByteBuf byteBuf = Unpooled.buffer(1024); + + @Test + public void size() { + ByteBuf metadata = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4}); + int length = LeaseFrameFlyweight.encode(byteBuf, 0, 0, metadata); + assertEquals(length, 9 + 4 * 2 + 4); // Frame header + ttl + #requests + 4 byte metadata + } + + @Test + public void testEncoding() { + int encoded = + LeaseFrameFlyweight.encode( + byteBuf, 0, 0, Unpooled.copiedBuffer("md", StandardCharsets.UTF_8)); + assertEquals( + "00001000000000090000000000000000006d64", ByteBufUtil.hexDump(byteBuf, 0, encoded)); + }*/ +} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/ClientServerInputMultiplexerTest.java b/rsocket-core/src/test/java/io/rsocket/internal/ClientServerInputMultiplexerTest.java deleted file mode 100644 index ca576ff03..000000000 --- a/rsocket-core/src/test/java/io/rsocket/internal/ClientServerInputMultiplexerTest.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.internal; - -import static org.junit.Assert.assertEquals; - -import io.rsocket.Frame; -import io.rsocket.plugins.PluginRegistry; -import io.rsocket.test.util.TestDuplexConnection; -import java.util.concurrent.atomic.AtomicInteger; -import org.junit.Before; -import org.junit.Test; - -public class ClientServerInputMultiplexerTest { - private TestDuplexConnection source; - private ClientServerInputMultiplexer multiplexer; - - @Before - public void setup() { - source = new TestDuplexConnection(); - multiplexer = new ClientServerInputMultiplexer(source, new PluginRegistry()); - } - - @Test - public void testSplits() { - AtomicInteger clientFrames = new AtomicInteger(); - AtomicInteger serverFrames = new AtomicInteger(); - AtomicInteger connectionFrames = new AtomicInteger(); - - multiplexer - .asClientConnection() - .receive() - .doOnNext(f -> clientFrames.incrementAndGet()) - .subscribe(); - multiplexer - .asServerConnection() - .receive() - .doOnNext(f -> serverFrames.incrementAndGet()) - .subscribe(); - multiplexer - .asStreamZeroConnection() - .receive() - .doOnNext(f -> connectionFrames.incrementAndGet()) - .subscribe(); - - source.addToReceivedBuffer(Frame.Error.from(1, new Exception())); - assertEquals(1, clientFrames.get()); - assertEquals(0, serverFrames.get()); - assertEquals(0, connectionFrames.get()); - - source.addToReceivedBuffer(Frame.Error.from(2, new Exception())); - assertEquals(1, clientFrames.get()); - assertEquals(1, serverFrames.get()); - assertEquals(0, connectionFrames.get()); - - source.addToReceivedBuffer(Frame.Error.from(1, new Exception())); - assertEquals(2, clientFrames.get()); - assertEquals(1, serverFrames.get()); - assertEquals(0, connectionFrames.get()); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/SchedulerUtils.java b/rsocket-core/src/test/java/io/rsocket/internal/SchedulerUtils.java new file mode 100644 index 000000000..d73f92b85 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/internal/SchedulerUtils.java @@ -0,0 +1,23 @@ +package io.rsocket.internal; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import reactor.core.scheduler.Scheduler; + +public class SchedulerUtils { + + public static void warmup(Scheduler scheduler) throws InterruptedException { + warmup(scheduler, 10000); + } + + public static void warmup(Scheduler scheduler, int times) throws InterruptedException { + scheduler.start(); + + // warm up + CountDownLatch latch = new CountDownLatch(times); + for (int i = 0; i < times; i++) { + scheduler.schedule(latch::countDown); + } + latch.await(5, TimeUnit.SECONDS); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java b/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java new file mode 100644 index 000000000..343a93beb --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java @@ -0,0 +1,366 @@ +/* + * Copyright 2015-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.internal; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.RaceTestConstants; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.internal.subscriber.AssertSubscriber; +import java.time.Duration; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.Fuseable; +import reactor.core.publisher.Hooks; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.test.util.RaceTestUtils; + +public class UnboundedProcessorTest { + + @BeforeAll + public static void setup() { + Hooks.onErrorDropped(__ -> {}); + } + + @AfterAll + public static void teardown() { + Hooks.resetOnErrorDropped(); + } + + @ParameterizedTest( + name = + "Test that emitting {0} onNext before subscribe and requestN should deliver all the signals once the subscriber is available") + @ValueSource(ints = {10, 100, 10_000, 100_000, 1_000_000, 10_000_000}) + public void testOnNextBeforeSubscribeN(int n) { + UnboundedProcessor processor = new UnboundedProcessor(); + + for (int i = 0; i < n; i++) { + processor.onNext(Unpooled.EMPTY_BUFFER); + } + + processor.onComplete(); + + StepVerifier.create(processor.count()).expectNext(Long.valueOf(n)).verifyComplete(); + } + + @ParameterizedTest( + name = + "Test that emitting {0} onNext after subscribe and requestN should deliver all the signals") + @ValueSource(ints = {10, 100, 10_000}) + public void testOnNextAfterSubscribeN(int n) { + UnboundedProcessor processor = new UnboundedProcessor(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + processor.subscribe(assertSubscriber); + + for (int i = 0; i < n; i++) { + processor.onNext(Unpooled.EMPTY_BUFFER); + } + + assertSubscriber.awaitAndAssertNextValueCount(n); + } + + @ParameterizedTest( + name = + "Test that prioritized value sending deliver prioritized signals before the others mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void testPrioritizedSending(boolean fusedCase) { + UnboundedProcessor processor = new UnboundedProcessor(); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + processor.onNext(Unpooled.EMPTY_BUFFER); + } + + processor.onNextPrioritized(Unpooled.copiedBuffer("test", CharsetUtil.UTF_8)); + + assertThat(fusedCase ? processor.poll() : processor.next().block()) + .isNotNull() + .extracting(bb -> bb.toString(CharsetUtil.UTF_8)) + .isEqualTo("test"); + } + + @ParameterizedTest( + name = + "Ensures that racing between onNext | dispose | cancel | request(n) will not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void ensureUnboundedProcessorDisposesQueueProperly(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); + + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + + final AssertSubscriber assertSubscriber = + new AssertSubscriber(0) + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); + + unboundedProcessor.subscribe(assertSubscriber); + + RaceTestUtils.race( + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNext(buffer2); + }, + unboundedProcessor::dispose, + assertSubscriber::cancel, + () -> { + assertSubscriber.request(1); + assertSubscriber.request(1); + }); + + assertSubscriber.values().forEach(ReferenceCountUtil::release); + + allocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest( + name = + "Ensures that racing between onNext | dispose | cancel | request(n) | terminal will not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void smokeTest1(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + final RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); + + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + final ByteBuf buffer3 = allocator.buffer(3); + final ByteBuf buffer4 = allocator.buffer(4); + + final AssertSubscriber assertSubscriber = + new AssertSubscriber(0) + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); + + unboundedProcessor.subscribe(assertSubscriber); + + RaceTestUtils.race( + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNextPrioritized(buffer2); + }, + () -> { + unboundedProcessor.onNextPrioritized(buffer3); + unboundedProcessor.onNext(buffer4); + }, + unboundedProcessor::dispose, + unboundedProcessor::onComplete, + () -> unboundedProcessor.onError(runtimeException), + assertSubscriber::cancel, + () -> { + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + }); + + assertSubscriber.values().forEach(ReferenceCountUtil::release); + + allocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest( + name = + "Ensures that racing between onNext | dispose | subscribe | request(n) | terminal will not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void smokeTest2(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + final RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); + + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + final ByteBuf buffer3 = allocator.buffer(3); + final ByteBuf buffer4 = allocator.buffer(4); + + final AssertSubscriber assertSubscriber = + new AssertSubscriber(0) + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); + + RaceTestUtils.race( + Schedulers.boundedElastic(), + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNextPrioritized(buffer2); + }, + () -> { + unboundedProcessor.onNextPrioritized(buffer3); + unboundedProcessor.onNext(buffer4); + }, + unboundedProcessor::dispose, + unboundedProcessor::onComplete, + () -> unboundedProcessor.onError(runtimeException), + () -> { + unboundedProcessor.subscribe(assertSubscriber); + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + }); + + assertSubscriber.values().forEach(ReferenceCountUtil::release); + + allocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest( + name = + "Ensures that racing between onNext | dispose | subscribe(cancelled) | terminal will not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void smokeTest3(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + final RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); + + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + final ByteBuf buffer3 = allocator.buffer(3); + final ByteBuf buffer4 = allocator.buffer(4); + + final AssertSubscriber assertSubscriber = + new AssertSubscriber(0) + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); + + assertSubscriber.cancel(); + + RaceTestUtils.race( + Schedulers.boundedElastic(), + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNextPrioritized(buffer2); + }, + () -> { + unboundedProcessor.onNextPrioritized(buffer3); + unboundedProcessor.onNext(buffer4); + }, + unboundedProcessor::dispose, + unboundedProcessor::onComplete, + () -> unboundedProcessor.onError(runtimeException), + () -> unboundedProcessor.subscribe(assertSubscriber)); + + assertSubscriber.values().forEach(ReferenceCountUtil::release); + + allocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest( + name = + "Ensures that racing between onNext | dispose | subscribe(cancelled) | terminal will not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void smokeTest31(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + final RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); + + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + final ByteBuf buffer3 = allocator.buffer(3); + final ByteBuf buffer4 = allocator.buffer(4); + + final AssertSubscriber assertSubscriber = + new AssertSubscriber(0) + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); + + RaceTestUtils.race( + Schedulers.boundedElastic(), + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNextPrioritized(buffer2); + }, + () -> { + unboundedProcessor.onNextPrioritized(buffer3); + unboundedProcessor.onNext(buffer4); + }, + unboundedProcessor::dispose, + unboundedProcessor::onComplete, + () -> unboundedProcessor.onError(runtimeException), + () -> unboundedProcessor.subscribe(assertSubscriber), + () -> { + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + }, + assertSubscriber::cancel); + + assertSubscriber.values().forEach(ReferenceCountUtil::release); + allocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest( + name = + "Ensures that racing between onNext + dispose | downstream async drain should not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void ensuresAsyncFusionAndDisposureHasNoDeadlock(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + final ByteBuf buffer3 = allocator.buffer(3); + final ByteBuf buffer4 = allocator.buffer(4); + final ByteBuf buffer5 = allocator.buffer(5); + final ByteBuf buffer6 = allocator.buffer(6); + + final AssertSubscriber assertSubscriber = + new AssertSubscriber() + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); + + unboundedProcessor.subscribe(assertSubscriber); + + RaceTestUtils.race( + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNext(buffer2); + unboundedProcessor.onNext(buffer3); + unboundedProcessor.onNext(buffer4); + unboundedProcessor.onNext(buffer5); + unboundedProcessor.onNext(buffer6); + unboundedProcessor.dispose(); + }, + unboundedProcessor::dispose); + + assertSubscriber.await(Duration.ofSeconds(50)).values().forEach(ReferenceCountUtil::release); + } + + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java b/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java new file mode 100644 index 000000000..b6eac9835 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java @@ -0,0 +1,1277 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.internal.subscriber; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.BooleanSupplier; +import java.util.function.Consumer; +import java.util.function.Supplier; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; +import reactor.core.Scannable; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.context.Context; + +/** + * A Subscriber implementation that hosts assertion tests for its state and allows asynchronous + * cancellation and requesting. + * + *

To create a new instance of {@link AssertSubscriber}, you have the choice between these static + * methods: + * + *

    + *
  • {@link AssertSubscriber#create()}: create a new {@link AssertSubscriber} and requests an + * unbounded number of elements. + *
  • {@link AssertSubscriber#create(long)}: create a new {@link AssertSubscriber} and requests + * {@code n} elements (can be 0 if you want no initial demand). + *
+ * + *

If you are testing asynchronous publishers, don't forget to use one of the {@code await*()} + * methods to wait for the data to assert. + * + *

You can extend this class but only the onNext, onError and onComplete can be overridden. You + * can call {@link #request(long)} and {@link #cancel()} from any thread or from within the + * overridable methods but you should avoid calling the assertXXX methods asynchronously. + * + *

Usage: + * + *

{@code
+ * AssertSubscriber
+ *   .subscribe(publisher)
+ *   .await()
+ *   .assertValues("ABC", "DEF");
+ * }
+ * + * @param the value type. + * @author Sebastien Deleuze + * @author David Karnok + * @author Anatoly Kadyshev + * @author Stephane Maldini + * @author Brian Clozel + */ +public class AssertSubscriber implements CoreSubscriber, Subscription, Scannable { + + /** Default timeout for waiting next values to be received */ + public static final Duration DEFAULT_VALUES_TIMEOUT = Duration.ofSeconds(3); + + @SuppressWarnings("rawtypes") + private static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(AssertSubscriber.class, "requested"); + + @SuppressWarnings("rawtypes") + private static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(AssertSubscriber.class, "wip"); + + @SuppressWarnings("rawtypes") + private static final AtomicReferenceFieldUpdater NEXT_VALUES = + AtomicReferenceFieldUpdater.newUpdater(AssertSubscriber.class, List.class, "values"); + + @SuppressWarnings("rawtypes") + private static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(AssertSubscriber.class, Subscription.class, "s"); + + private final Context context; + + private final List errors = new LinkedList<>(); + + private final CountDownLatch cdl = new CountDownLatch(1); + + volatile boolean done; + + volatile Subscription s; + + volatile long requested; + + volatile int wip; + + volatile List values = new LinkedList<>(); + + /** The fusion mode to request. */ + private int requestedFusionMode = -1; + + /** The established fusion mode. */ + private volatile int establishedFusionMode = -1; + + /** The fuseable QueueSubscription in case a fusion mode was specified. */ + private Fuseable.QueueSubscription qs; + + private int subscriptionCount = 0; + + private int completionCount = 0; + + private volatile long valueCount = 0L; + + private volatile long nextValueAssertedCount = 0L; + + private Duration valuesTimeout = DEFAULT_VALUES_TIMEOUT; + + private boolean valuesStorage = true; + + // + // ============================================================================================================== + // Static methods + // + // ============================================================================================================== + + /** + * Blocking method that waits until {@code conditionSupplier} returns true, or if it does not + * before the specified timeout, throws an {@link AssertionError} with the specified error message + * supplier. + * + * @param timeout the timeout duration + * @param errorMessageSupplier the error message supplier + * @param conditionSupplier condition to break out of the wait loop + * @throws AssertionError + */ + public static void await( + Duration timeout, Supplier errorMessageSupplier, BooleanSupplier conditionSupplier) { + + Objects.requireNonNull(errorMessageSupplier); + Objects.requireNonNull(conditionSupplier); + Objects.requireNonNull(timeout); + + long timeoutNs = timeout.toNanos(); + long startTime = System.nanoTime(); + do { + if (conditionSupplier.getAsBoolean()) { + return; + } + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } while (System.nanoTime() - startTime < timeoutNs); + throw new AssertionError(errorMessageSupplier.get()); + } + + /** + * Blocking method that waits until {@code conditionSupplier} returns true, or if it does not + * before the specified timeout, throw an {@link AssertionError} with the specified error message. + * + * @param timeout the timeout duration + * @param errorMessage the error message + * @param conditionSupplier condition to break out of the wait loop + * @throws AssertionError + */ + public static void await( + Duration timeout, final String errorMessage, BooleanSupplier conditionSupplier) { + await( + timeout, + new Supplier() { + @Override + public String get() { + return errorMessage; + } + }, + conditionSupplier); + } + + /** + * Create a new {@link AssertSubscriber} that requests an unbounded number of elements. + * + *

Be sure at least a publisher has subscribed to it via {@link + * Publisher#subscribe(Subscriber)} before use assert methods. + * + * @param the observed value type + * @return a fresh AssertSubscriber instance + */ + public static AssertSubscriber create() { + return new AssertSubscriber<>(); + } + + /** + * Create a new {@link AssertSubscriber} that requests initially {@code n} elements. You can then + * manage the demand with {@link Subscription#request(long)}. + * + *

Be sure at least a publisher has subscribed to it via {@link + * Publisher#subscribe(Subscriber)} before use assert methods. + * + * @param n Number of elements to request (can be 0 if you want no initial demand). + * @param the observed value type + * @return a fresh AssertSubscriber instance + */ + public static AssertSubscriber create(long n) { + return new AssertSubscriber<>(n); + } + + // + // ============================================================================================================== + // constructors + // + // ============================================================================================================== + + public AssertSubscriber() { + this(Context.empty(), Long.MAX_VALUE); + } + + public AssertSubscriber(long n) { + this(Context.empty(), n); + } + + public AssertSubscriber(Context context) { + this(context, Long.MAX_VALUE); + } + + public AssertSubscriber(Context context, long n) { + if (n < 0) { + throw new IllegalArgumentException("initialRequest >= required but it was " + n); + } + this.context = context; + REQUESTED.lazySet(this, n); + } + + // + // ============================================================================================================== + // Configuration + // + // ============================================================================================================== + + /** + * Enable or disabled the values storage. It is enabled by default, and can be disable in order to + * be able to perform performance benchmarks or tests with a huge amount values. + * + * @param enabled enable value storage? + * @return this + */ + public final AssertSubscriber configureValuesStorage(boolean enabled) { + this.valuesStorage = enabled; + return this; + } + + /** + * Configure the timeout in seconds for waiting next values to be received (3 seconds by default). + * + * @param timeout the new default value timeout duration + * @return this + */ + public final AssertSubscriber configureValuesTimeout(Duration timeout) { + this.valuesTimeout = timeout; + return this; + } + + /** + * Returns the established fusion mode or -1 if it was not enabled + * + * @return the fusion mode, see Fuseable constants + */ + public final int establishedFusionMode() { + return establishedFusionMode; + } + + // + // ============================================================================================================== + // Assertions + // + // ============================================================================================================== + + /** + * Assert a complete successfully signal has been received. + * + * @return this + */ + public final AssertSubscriber assertComplete() { + assertNoError(); + int c = completionCount; + if (c == 0) { + throw new AssertionError("Not completed", null); + } + if (c > 1) { + throw new AssertionError("Multiple completions: " + c, null); + } + return this; + } + + /** + * Assert the specified values have been received. Values storage should be enabled to use this + * method. + * + * @param expectedValues the values to assert + * @see #configureValuesStorage(boolean) + * @return this + */ + public final AssertSubscriber assertContainValues(Set expectedValues) { + if (!valuesStorage) { + throw new IllegalStateException("Using assertNoValues() requires enabling values storage"); + } + if (expectedValues.size() > values.size()) { + throw new AssertionError("Actual contains fewer elements" + values, null); + } + + Iterator expected = expectedValues.iterator(); + + for (; ; ) { + boolean n2 = expected.hasNext(); + if (n2) { + T t2 = expected.next(); + if (!values.contains(t2)) { + throw new AssertionError( + "The element is not contained in the " + + "received results" + + " = " + + valueAndClass(t2), + null); + } + } else { + break; + } + } + return this; + } + + /** + * Assert an error signal has been received. + * + * @return this + */ + public final AssertSubscriber assertError() { + assertNotComplete(); + int s = errors.size(); + if (s == 0) { + throw new AssertionError("No error", null); + } + if (s > 1) { + throw new AssertionError("Multiple errors: " + s, null); + } + return this; + } + + /** + * Assert an error signal has been received. + * + * @param clazz The class of the exception contained in the error signal + * @return this + */ + public final AssertSubscriber assertError(Class clazz) { + assertNotComplete(); + int s = errors.size(); + if (s == 0) { + throw new AssertionError("No error", null); + } + if (s == 1) { + Throwable e = errors.get(0); + if (!clazz.isInstance(e)) { + throw new AssertionError( + "Error class incompatible: expected = " + clazz + ", actual = " + e, null); + } + } + if (s > 1) { + throw new AssertionError("Multiple errors: " + errors, null); + } + return this; + } + + public final AssertSubscriber assertErrorMessage(String message) { + assertNotComplete(); + int s = errors.size(); + if (s == 0) { + assertionError("No error", null); + } + if (s == 1) { + if (!Objects.equals(message, errors.get(0).getMessage())) { + assertionError( + "Error class incompatible: expected = \"" + + message + + "\", actual = \"" + + errors.get(0).getMessage() + + "\"", + null); + } + } + if (s > 1) { + assertionError("Multiple errors: " + s, null); + } + + return this; + } + + /** + * Assert an error signal has been received. + * + * @param expectation A method that can verify the exception contained in the error signal and + * throw an exception (like an {@link AssertionError}) if the exception is not valid. + * @return this + */ + public final AssertSubscriber assertErrorWith(Consumer expectation) { + assertNotComplete(); + int s = errors.size(); + if (s == 0) { + throw new AssertionError("No error", null); + } + if (s == 1) { + expectation.accept(errors.get(0)); + } + if (s > 1) { + throw new AssertionError("Multiple errors: " + s, null); + } + return this; + } + + /** + * Assert that the upstream was a Fuseable source. + * + * @return this + */ + public final AssertSubscriber assertFuseableSource() { + if (qs == null) { + throw new AssertionError("Upstream was not Fuseable"); + } + return this; + } + + /** + * Assert that the fusion mode was granted. + * + * @return this + */ + public final AssertSubscriber assertFusionEnabled() { + if (establishedFusionMode != Fuseable.SYNC && establishedFusionMode != Fuseable.ASYNC) { + throw new AssertionError("Fusion was not enabled"); + } + return this; + } + + public final AssertSubscriber assertFusionMode(int expectedMode) { + if (establishedFusionMode != expectedMode) { + throw new AssertionError( + "Wrong fusion mode: expected: " + + fusionModeName(expectedMode) + + ", actual: " + + fusionModeName(establishedFusionMode)); + } + return this; + } + + /** + * Assert that the fusion mode was granted. + * + * @return this + */ + public final AssertSubscriber assertFusionRejected() { + if (establishedFusionMode != Fuseable.NONE) { + throw new AssertionError("Fusion was granted"); + } + return this; + } + + /** + * Assert no error signal has been received. + * + * @return this + */ + public final AssertSubscriber assertNoError() { + int s = errors.size(); + if (s == 1) { + Throwable e = errors.get(0); + String valueAndClass = e == null ? null : e + " (" + e.getClass().getSimpleName() + ")"; + throw new AssertionError("Error present: " + valueAndClass, null); + } + if (s > 1) { + throw new AssertionError("Multiple errors: " + s, null); + } + return this; + } + + /** + * Assert no values have been received. + * + * @return this + */ + public final AssertSubscriber assertNoValues() { + if (valueCount != 0) { + throw new AssertionError( + "No values expected but received: [length = " + values.size() + "] " + values, null); + } + return this; + } + + /** + * Assert that the upstream was not a Fuseable source. + * + * @return this + */ + public final AssertSubscriber assertNonFuseableSource() { + if (qs != null) { + throw new AssertionError("Upstream was Fuseable"); + } + return this; + } + + /** + * Assert no complete successfully signal has been received. + * + * @return this + */ + public final AssertSubscriber assertNotComplete() { + int c = completionCount; + if (c == 1) { + throw new AssertionError("Completed", null); + } + if (c > 1) { + throw new AssertionError("Multiple completions: " + c, null); + } + return this; + } + + /** + * Assert no subscription occurred. + * + * @return this + */ + public final AssertSubscriber assertNotSubscribed() { + int s = subscriptionCount; + + if (s == 1) { + throw new AssertionError("OnSubscribe called once", null); + } + if (s > 1) { + throw new AssertionError("OnSubscribe called multiple times: " + s, null); + } + + return this; + } + + /** + * Assert no complete successfully or error signal has been received. + * + * @return this + */ + public final AssertSubscriber assertNotTerminated() { + if (cdl.getCount() == 0) { + throw new AssertionError("Terminated", null); + } + return this; + } + + /** + * Assert subscription occurred (once). + * + * @return this + */ + public final AssertSubscriber assertSubscribed() { + int s = subscriptionCount; + + if (s == 0) { + throw new AssertionError("OnSubscribe not called", null); + } + if (s > 1) { + throw new AssertionError("OnSubscribe called multiple times: " + s, null); + } + + return this; + } + + /** + * Assert either complete successfully or error signal has been received. + * + * @return this + */ + public final AssertSubscriber assertTerminated() { + if (cdl.getCount() != 0) { + throw new AssertionError("Not terminated", null); + } + return this; + } + + /** + * Assert {@code n} values has been received. + * + * @param n the expected value count + * @return this + */ + public final AssertSubscriber assertValueCount(long n) { + if (valueCount != n) { + throw new AssertionError( + "Different value count: expected = " + n + ", actual = " + valueCount, null); + } + return this; + } + + /** + * Assert the specified values have been received in the same order read by the passed {@link + * Iterable}. Values storage should be enabled to use this method. + * + * @param expectedSequence the values to assert + * @see #configureValuesStorage(boolean) + * @return this + */ + public final AssertSubscriber assertValueSequence(Iterable expectedSequence) { + if (!valuesStorage) { + throw new IllegalStateException("Using assertNoValues() requires enabling values storage"); + } + Iterator actual = values.iterator(); + Iterator expected = expectedSequence.iterator(); + int i = 0; + for (; ; ) { + boolean n1 = actual.hasNext(); + boolean n2 = expected.hasNext(); + if (n1 && n2) { + T t1 = actual.next(); + T t2 = expected.next(); + if (!Objects.equals(t1, t2)) { + throw new AssertionError( + "The element with index " + + i + + " does not match: expected = " + + valueAndClass(t2) + + ", actual = " + + valueAndClass(t1), + null); + } + i++; + } else if (n1 && !n2) { + throw new AssertionError("Actual contains more elements" + values, null); + } else if (!n1 && n2) { + throw new AssertionError("Actual contains fewer elements: " + values, null); + } else { + break; + } + } + return this; + } + + /** + * Assert the specified values have been received in the declared order. Values storage should be + * enabled to use this method. + * + * @param expectedValues the values to assert + * @return this + * @see #configureValuesStorage(boolean) + */ + @SafeVarargs + public final AssertSubscriber assertValues(T... expectedValues) { + return assertValueSequence(Arrays.asList(expectedValues)); + } + + /** + * Assert the specified values have been received in the declared order. Values storage should be + * enabled to use this method. + * + * @param expectations One or more methods that can verify the values and throw a exception (like + * an {@link AssertionError}) if the value is not valid. + * @return this + * @see #configureValuesStorage(boolean) + */ + @SafeVarargs + public final AssertSubscriber assertValuesWith(Consumer... expectations) { + if (!valuesStorage) { + throw new IllegalStateException("Using assertNoValues() requires enabling values storage"); + } + final int expectedValueCount = expectations.length; + if (expectedValueCount != values.size()) { + throw new AssertionError( + "Different value count: expected = " + expectedValueCount + ", actual = " + valueCount, + null); + } + for (int i = 0; i < expectedValueCount; i++) { + Consumer consumer = expectations[i]; + T actualValue = values.get(i); + consumer.accept(actualValue); + } + return this; + } + + // + // ============================================================================================================== + // Await methods + // + // ============================================================================================================== + + /** + * Blocking method that waits until a complete successfully or error signal is received. + * + * @return this + */ + public final AssertSubscriber await() { + if (cdl.getCount() == 0) { + return this; + } + try { + cdl.await(); + } catch (InterruptedException ex) { + throw new AssertionError("Wait interrupted", ex); + } + return this; + } + + /** + * Blocking method that waits until a complete successfully or error signal is received or until a + * timeout occurs. + * + * @param timeout The timeout value + * @return this + */ + public final AssertSubscriber await(Duration timeout) { + if (cdl.getCount() == 0) { + return this; + } + try { + if (!cdl.await(timeout.toMillis(), TimeUnit.MILLISECONDS)) { + throw new AssertionError("No complete or error signal before timeout"); + } + return this; + } catch (InterruptedException ex) { + throw new AssertionError("Wait interrupted", ex); + } + } + + /** + * Blocking method that waits until {@code n} next values have been received. + * + * @param n the value count to assert + * @return this + */ + public final AssertSubscriber awaitAndAssertNextValueCount(final long n) { + await( + valuesTimeout, + () -> { + if (valuesStorage) { + return String.format( + "%d out of %d next values received within %d, " + "values : %s", + valueCount - nextValueAssertedCount, + n, + valuesTimeout.toMillis(), + values.toString()); + } + return String.format( + "%d out of %d next values received within %d", + valueCount - nextValueAssertedCount, n, valuesTimeout.toMillis()); + }, + () -> valueCount >= (nextValueAssertedCount + n)); + nextValueAssertedCount += n; + return this; + } + + /** + * Blocking method that waits until {@code n} next values have been received (n is the number of + * values provided) to assert them. + * + * @param values the values to assert + * @return this + */ + @SafeVarargs + @SuppressWarnings("unchecked") + public final AssertSubscriber awaitAndAssertNextValues(T... values) { + final int expectedNum = values.length; + final List> expectations = new ArrayList<>(); + for (int i = 0; i < expectedNum; i++) { + final T expectedValue = values[i]; + expectations.add( + actualValue -> { + if (!actualValue.equals(expectedValue)) { + throw new AssertionError( + String.format( + "Expected Next signal: %s, but got: %s", expectedValue, actualValue)); + } + }); + } + awaitAndAssertNextValuesWith(expectations.toArray((Consumer[]) new Consumer[0])); + return this; + } + + /** + * Blocking method that waits until {@code n} next values have been received (n is the number of + * expectations provided) to assert them. + * + * @param expectations One or more methods that can verify the values and throw a exception (like + * an {@link AssertionError}) if the value is not valid. + * @return this + */ + @SafeVarargs + public final AssertSubscriber awaitAndAssertNextValuesWith(Consumer... expectations) { + valuesStorage = true; + final int expectedValueCount = expectations.length; + await( + valuesTimeout, + () -> { + if (valuesStorage) { + return String.format( + "%d out of %d next values received within %d, " + "values : %s", + valueCount - nextValueAssertedCount, + expectedValueCount, + valuesTimeout.toMillis(), + values.toString()); + } + return String.format( + "%d out of %d next values received within %d ms", + valueCount - nextValueAssertedCount, expectedValueCount, valuesTimeout.toMillis()); + }, + () -> valueCount >= (nextValueAssertedCount + expectedValueCount)); + List nextValuesSnapshot; + List empty = new ArrayList<>(); + for (; ; ) { + nextValuesSnapshot = values; + if (NEXT_VALUES.compareAndSet(this, values, empty)) { + break; + } + } + if (nextValuesSnapshot.size() < expectedValueCount) { + throw new AssertionError( + String.format( + "Expected %d number of signals but received %d", + expectedValueCount, nextValuesSnapshot.size())); + } + for (int i = 0; i < expectedValueCount; i++) { + Consumer consumer = expectations[i]; + T actualValue = nextValuesSnapshot.get(i); + consumer.accept(actualValue); + } + nextValueAssertedCount += expectedValueCount; + return this; + } + + // + // ============================================================================================================== + // Overrides + // + // ============================================================================================================== + + @Override + public void cancel() { + Subscription a = s; + if (a != Operators.cancelledSubscription()) { + a = S.getAndSet(this, Operators.cancelledSubscription()); + if (a != null && a != Operators.cancelledSubscription()) { + a.cancel(); + + if (establishedFusionMode == Fuseable.ASYNC) { + final int previousState = markWorkAdded(); + if (!isWorkInProgress(previousState)) { + clearAndFinalize(); + } + } + } + } + } + + final boolean isCancelled() { + return s == Operators.cancelledSubscription(); + } + + public final boolean isTerminated() { + return cdl.getCount() == 0; + } + + @Override + public void onComplete() { + done = true; + completionCount++; + + if (establishedFusionMode == Fuseable.ASYNC) { + drain(); + return; + } + + cdl.countDown(); + } + + @Override + public void onError(Throwable t) { + done = true; + errors.add(t); + + if (establishedFusionMode == Fuseable.ASYNC) { + drain(); + return; + } + + cdl.countDown(); + } + + @Override + public void onNext(T t) { + if (establishedFusionMode == Fuseable.ASYNC) { + drain(); + } else { + valueCount++; + if (valuesStorage) { + List nextValuesSnapshot; + for (; ; ) { + nextValuesSnapshot = values; + nextValuesSnapshot.add(t); + if (NEXT_VALUES.compareAndSet(this, nextValuesSnapshot, nextValuesSnapshot)) { + break; + } + } + } + } + } + + static boolean isFinalized(int state) { + return state == Integer.MIN_VALUE; + } + + static boolean isWorkInProgress(int state) { + return state > 0; + } + + int markWorkAdded() { + for (; ; ) { + int state = this.wip; + + if (isFinalized(state)) { + return state; + } + + if ((state & Integer.MAX_VALUE) == Integer.MAX_VALUE) { + return state; + } + int nextState = state + 1; + + if (WIP.compareAndSet(this, state, nextState)) { + return state; + } + } + } + + void clearAndFinalize() { + final Fuseable.QueueSubscription qs = this.qs; + for (; ; ) { + int state = this.wip; + + qs.clear(); + + if (WIP.compareAndSet(this, state, Integer.MIN_VALUE)) { + return; + } + } + } + + void drain() { + final int previousState = markWorkAdded(); + if (isWorkInProgress(previousState)) { + return; + } + + if (isFinalized(previousState)) { + qs.clear(); + return; + } + + T t; + int m = 1; + for (; ; ) { + if (isCancelled()) { + clearAndFinalize(); + break; + } + boolean done = this.done; + t = qs.poll(); + if (t == null) { + if (done) { + clearAndFinalize(); + cdl.countDown(); + return; + } + m = WIP.addAndGet(this, -m); + if (m == 0) { + break; + } + continue; + } + valueCount++; + if (valuesStorage) { + List nextValuesSnapshot; + for (; ; ) { + nextValuesSnapshot = values; + nextValuesSnapshot.add(t); + if (NEXT_VALUES.compareAndSet(this, nextValuesSnapshot, nextValuesSnapshot)) { + break; + } + } + } + } + } + + @Override + @SuppressWarnings("unchecked") + public void onSubscribe(Subscription s) { + subscriptionCount++; + int requestMode = requestedFusionMode; + if (requestMode >= 0) { + if (s instanceof Fuseable.QueueSubscription) { + this.qs = (Fuseable.QueueSubscription) s; + + int m = qs.requestFusion(requestMode); + establishedFusionMode = m; + + if (!setWithoutRequesting(s)) { + qs.clear(); + if (!isCancelled()) { + errors.add(new IllegalStateException("Subscription already set: " + subscriptionCount)); + } + return; + } + + if (m == Fuseable.SYNC) { + for (; ; ) { + T v = qs.poll(); + if (v == null) { + onComplete(); + break; + } + + onNext(v); + } + } else { + requestDeferred(); + } + + return; + } + } + + if (!set(s)) { + if (!isCancelled()) { + errors.add(new IllegalStateException("Subscription already set: " + subscriptionCount)); + } + } + } + + @Override + public void request(long n) { + if (Operators.validate(n)) { + if (establishedFusionMode != Fuseable.SYNC) { + normalRequest(n); + } + } + } + + @Override + @NonNull + public Context currentContext() { + return context; + } + + /** + * Setup what fusion mode should be requested from the incoming Subscription if it happens to be + * QueueSubscription + * + * @param requestMode the mode to request, see Fuseable constants + * @return this + */ + public final AssertSubscriber requestedFusionMode(int requestMode) { + this.requestedFusionMode = requestMode; + return this; + } + + public Subscription upstream() { + return s; + } + + // + // ============================================================================================================== + // Non public methods + // + // ============================================================================================================== + + protected final void normalRequest(long n) { + Subscription a = s; + if (a != null) { + a.request(n); + } else { + Operators.addCap(REQUESTED, this, n); + + a = s; + + if (a != null) { + long r = REQUESTED.getAndSet(this, 0L); + + if (r != 0L) { + a.request(r); + } + } + } + } + + /** Requests the deferred amount if not zero. */ + protected final void requestDeferred() { + long r = REQUESTED.getAndSet(this, 0L); + + if (r != 0L) { + s.request(r); + } + } + + /** + * Atomically sets the single subscription and requests the missed amount from it. + * + * @param s + * @return false if this arbiter is cancelled or there was a subscription already set + */ + protected final boolean set(Subscription s) { + Objects.requireNonNull(s, "s"); + Subscription a = this.s; + if (a == Operators.cancelledSubscription()) { + s.cancel(); + return false; + } + if (a != null) { + s.cancel(); + Operators.reportSubscriptionSet(); + return false; + } + + if (S.compareAndSet(this, null, s)) { + + long r = REQUESTED.getAndSet(this, 0L); + + if (r != 0L) { + s.request(r); + } + + return true; + } + + a = this.s; + + if (a != Operators.cancelledSubscription()) { + s.cancel(); + return false; + } + + Operators.reportSubscriptionSet(); + return false; + } + + /** + * Sets the Subscription once but does not request anything. + * + * @param s the Subscription to set + * @return true if successful, false if the current subscription is not null + */ + protected final boolean setWithoutRequesting(Subscription s) { + Objects.requireNonNull(s, "s"); + for (; ; ) { + Subscription a = this.s; + if (a == Operators.cancelledSubscription()) { + s.cancel(); + return false; + } + if (a != null) { + s.cancel(); + Operators.reportSubscriptionSet(); + return false; + } + + if (S.compareAndSet(this, null, s)) { + return true; + } + } + } + + /** + * Prepares and throws an AssertionError exception based on the message, cause, the active state + * and the potential errors so far. + * + * @param message the message + * @param cause the optional Throwable cause + * @throws AssertionError as expected + */ + protected final void assertionError(String message, Throwable cause) { + StringBuilder b = new StringBuilder(); + + if (cdl.getCount() != 0) { + b.append("(active) "); + } + b.append(message); + + List err = errors; + if (!err.isEmpty()) { + b.append(" (+ ").append(err.size()).append(" errors)"); + } + AssertionError e = new AssertionError(b.toString(), cause); + + for (Throwable t : err) { + e.addSuppressed(t); + } + + throw e; + } + + protected final String fusionModeName(int mode) { + switch (mode) { + case -1: + return "Disabled"; + case Fuseable.NONE: + return "None"; + case Fuseable.SYNC: + return "Sync"; + case Fuseable.ASYNC: + return "Async"; + default: + return "Unknown(" + mode + ")"; + } + } + + protected final String valueAndClass(Object o) { + if (o == null) { + return null; + } + return o + " (" + o.getClass().getSimpleName() + ")"; + } + + public List values() { + return values; + } + + public List errors() { + return errors; + } + + public final AssertSubscriber assertNoEvents() { + return assertNoValues().assertNoError().assertNotComplete(); + } + + @SafeVarargs + public final AssertSubscriber assertIncomplete(T... values) { + return assertValues(values).assertNotComplete().assertNoError(); + } + + @Override + public Object scanUnsafe(Attr key) { + if (key == Attr.PARENT) { + return upstream(); + } + + boolean t = isTerminated(); + if (key == Attr.TERMINATED) return t; + if (key == Attr.ERROR) return (!errors.isEmpty() ? errors.get(0) : null); + if (key == Attr.PREFETCH) return Integer.MAX_VALUE; + if (key == Attr.CANCELLED) return isCancelled(); + if (key == Attr.RUN_STYLE) return Attr.RunStyle.SYNC; + + return null; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/lease/LeaseImplTest.java b/rsocket-core/src/test/java/io/rsocket/lease/LeaseImplTest.java new file mode 100644 index 000000000..9ebca34f7 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/lease/LeaseImplTest.java @@ -0,0 +1,78 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.lease; + +public class LeaseImplTest { + // + // @Test + // public void emptyLeaseNoAvailability() { + // LeaseImpl empty = LeaseImpl.empty(); + // Assertions.assertTrue(empty.isEmpty()); + // Assertions.assertFalse(empty.isValid()); + // Assertions.assertEquals(0.0, empty.availability(), 1e-5); + // } + // + // @Test + // public void emptyLeaseUseNoAvailability() { + // LeaseImpl empty = LeaseImpl.empty(); + // boolean success = empty.use(); + // assertFalse(success); + // Assertions.assertEquals(0.0, empty.availability(), 1e-5); + // } + // + // @Test + // public void leaseAvailability() { + // LeaseImpl lease = LeaseImpl.create(2, 100, Unpooled.EMPTY_BUFFER); + // Assertions.assertEquals(1.0, lease.availability(), 1e-5); + // } + // + // @Test + // public void leaseUseDecreasesAvailability() { + // LeaseImpl lease = LeaseImpl.create(30_000, 2, Unpooled.EMPTY_BUFFER); + // boolean success = lease.use(); + // Assertions.assertTrue(success); + // Assertions.assertEquals(0.5, lease.availability(), 1e-5); + // Assertions.assertTrue(lease.isValid()); + // success = lease.use(); + // Assertions.assertTrue(success); + // Assertions.assertEquals(0.0, lease.availability(), 1e-5); + // Assertions.assertFalse(lease.isValid()); + // Assertions.assertEquals(0, lease.getAllowedRequests()); + // success = lease.use(); + // Assertions.assertFalse(success); + // } + // + // @Test + // public void leaseTimeout() { + // int numberOfRequests = 1; + // LeaseImpl lease = LeaseImpl.create(1, numberOfRequests, Unpooled.EMPTY_BUFFER); + // Mono.delay(Duration.ofMillis(100)).block(); + // boolean success = lease.use(); + // Assertions.assertFalse(success); + // Assertions.assertTrue(lease.isExpired()); + // Assertions.assertEquals(numberOfRequests, lease.getAllowedRequests()); + // Assertions.assertFalse(lease.isValid()); + // } + // + // @Test + // public void useLeaseChangesAllowedRequests() { + // int numberOfRequests = 2; + // LeaseImpl lease = LeaseImpl.create(30_000, numberOfRequests, Unpooled.EMPTY_BUFFER); + // lease.use(); + // assertEquals(numberOfRequests - 1, lease.getAllowedRequests()); + // } +} diff --git a/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceRSocketClientTest.java b/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceRSocketClientTest.java new file mode 100644 index 000000000..a35e89391 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceRSocketClientTest.java @@ -0,0 +1,94 @@ +package io.rsocket.loadbalance; + +import static java.util.Collections.singletonList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketClient; +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.ClientTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +@ExtendWith(MockitoExtension.class) +class LoadbalanceRSocketClientTest { + + @Mock private ClientTransport clientTransport; + @Mock private RSocketConnector rSocketConnector; + + public static final Duration SHORT_DURATION = Duration.ofMillis(25); + public static final Duration LONG_DURATION = Duration.ofMillis(75); + + private static final Publisher SOURCE = + Flux.interval(SHORT_DURATION) + .onBackpressureBuffer() + .map(String::valueOf) + .map(DefaultPayload::create); + + private static final Mono PROGRESSING_HANDLER = + Mono.just( + new RSocket() { + private final AtomicInteger i = new AtomicInteger(); + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads) + .delayElements(SHORT_DURATION) + .map(Payload::getDataUtf8) + .map(DefaultPayload::create) + .take(i.incrementAndGet()); + } + }); + + @Test + void testChannelReconnection() { + when(rSocketConnector.connect(clientTransport)).thenReturn(PROGRESSING_HANDLER); + + RSocketClient client = + LoadbalanceRSocketClient.create( + rSocketConnector, + Mono.just(singletonList(LoadbalanceTarget.from("key", clientTransport)))); + + Publisher result = + client + .requestChannel(SOURCE) + .repeatWhen(longFlux -> longFlux.delayElements(LONG_DURATION).take(5)) + .map(Payload::getDataUtf8) + .log(); + + StepVerifier.create(result) + .expectSubscription() + .assertNext(s -> assertThat(s).isEqualTo("0")) + .assertNext(s -> assertThat(s).isEqualTo("0")) + .assertNext(s -> assertThat(s).isEqualTo("1")) + .assertNext(s -> assertThat(s).isEqualTo("0")) + .assertNext(s -> assertThat(s).isEqualTo("1")) + .assertNext(s -> assertThat(s).isEqualTo("2")) + .assertNext(s -> assertThat(s).isEqualTo("0")) + .assertNext(s -> assertThat(s).isEqualTo("1")) + .assertNext(s -> assertThat(s).isEqualTo("2")) + .assertNext(s -> assertThat(s).isEqualTo("3")) + .assertNext(s -> assertThat(s).isEqualTo("0")) + .assertNext(s -> assertThat(s).isEqualTo("1")) + .assertNext(s -> assertThat(s).isEqualTo("2")) + .assertNext(s -> assertThat(s).isEqualTo("3")) + .assertNext(s -> assertThat(s).isEqualTo("4")) + .verifyComplete(); + + verify(rSocketConnector).connect(clientTransport); + verifyNoMoreInteractions(rSocketConnector, clientTransport); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceTest.java b/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceTest.java new file mode 100644 index 000000000..c1b509297 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceTest.java @@ -0,0 +1,470 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.RaceTestConstants; +import io.rsocket.core.RSocketConnector; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.transport.ClientTransport; +import io.rsocket.util.EmptyPayload; +import io.rsocket.util.RSocketProxy; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; +import reactor.util.context.Context; + +public class LoadbalanceTest { + + @BeforeEach + void setUp() { + Hooks.onErrorDropped((__) -> {}); + } + + @AfterAll + static void afterAll() { + Hooks.resetOnErrorDropped(); + } + + @Test + public void shouldDeliverAllTheRequestsWithRoundRobinStrategy() { + final AtomicInteger counter = new AtomicInteger(); + final ClientTransport mockTransport = new TestClientTransport(); + final RSocket rSocket = + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter.incrementAndGet(); + return Mono.empty(); + } + }; + + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + final ClientTransport mockTransport1 = Mockito.mock(ClientTransport.class); + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then(im -> Mono.just(new TestRSocket(rSocket))); + + final List collectionOfDestination1 = + Collections.singletonList(LoadbalanceTarget.from("1", mockTransport)); + final List collectionOfDestination2 = + Collections.singletonList(LoadbalanceTarget.from("2", mockTransport)); + final List collectionOfDestinations1And2 = + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport), LoadbalanceTarget.from("2", mockTransport)); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final Sinks.Many> source = + Sinks.unsafe().many().unicast().onBackpressureError(); + final RSocketPool rSocketPool = + new RSocketPool( + rSocketConnectorMock, source.asFlux(), new RoundRobinLoadbalanceStrategy()); + final Mono fnfSource = + Mono.defer(() -> rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE)); + + RaceTestUtils.race( + () -> { + for (int j = 0; j < 1000; j++) { + fnfSource.subscribe(new RetrySubscriber(fnfSource)); + } + }, + () -> { + for (int j = 0; j < 100; j++) { + source.emitNext(Collections.emptyList(), Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination1, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestinations1And2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination1, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(Collections.emptyList(), Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestinations1And2, Sinks.EmitFailureHandler.FAIL_FAST); + } + }); + + Assertions.assertThat(counter.get()).isEqualTo(1000); + counter.set(0); + } + } + + @Test + public void shouldDeliverAllTheRequestsWithWeightedStrategy() throws InterruptedException { + final AtomicInteger counter = new AtomicInteger(); + + final ClientTransport mockTransport1 = Mockito.mock(ClientTransport.class); + final ClientTransport mockTransport2 = Mockito.mock(ClientTransport.class); + + final LoadbalanceTarget target1 = LoadbalanceTarget.from("1", mockTransport1); + final LoadbalanceTarget target2 = LoadbalanceTarget.from("2", mockTransport2); + + final WeightedRSocket weightedRSocket1 = new WeightedRSocket(counter); + final WeightedRSocket weightedRSocket2 = new WeightedRSocket(counter); + + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + Mockito.when(rSocketConnectorMock.connect(mockTransport1)) + .then(im -> Mono.just(new TestRSocket(weightedRSocket1))); + Mockito.when(rSocketConnectorMock.connect(mockTransport2)) + .then(im -> Mono.just(new TestRSocket(weightedRSocket2))); + final List collectionOfDestination1 = Collections.singletonList(target1); + final List collectionOfDestination2 = Collections.singletonList(target2); + final List collectionOfDestinations1And2 = Arrays.asList(target1, target2); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final Sinks.Many> source = + Sinks.unsafe().many().unicast().onBackpressureError(); + final RSocketPool rSocketPool = + new RSocketPool( + rSocketConnectorMock, + source.asFlux(), + WeightedLoadbalanceStrategy.builder() + .weightedStatsResolver( + rsocket -> { + if (rsocket instanceof TestRSocket) { + return (WeightedRSocket) ((TestRSocket) rsocket).source(); + } + return ((PooledRSocket) rsocket).target() == target1 + ? weightedRSocket1 + : weightedRSocket2; + }) + .build()); + final Mono fnfSource = + Mono.defer(() -> rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE)); + + RaceTestUtils.race( + () -> { + for (int j = 0; j < 1000; j++) { + fnfSource.subscribe(new RetrySubscriber(fnfSource)); + } + }, + () -> { + for (int j = 0; j < 100; j++) { + source.emitNext(Collections.emptyList(), Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination1, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestinations1And2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination1, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(Collections.emptyList(), Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestinations1And2, Sinks.EmitFailureHandler.FAIL_FAST); + } + }); + + Assertions.assertThat(counter.get()).isEqualTo(1000); + counter.set(0); + } + } + + @Test + public void ensureRSocketIsCleanedFromThePoolIfSourceRSocketIsDisposed() { + final AtomicInteger counter = new AtomicInteger(); + final ClientTransport mockTransport = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + + final TestRSocket testRSocket = + new TestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter.incrementAndGet(); + return Mono.empty(); + } + }); + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then(im -> Mono.delay(Duration.ofMillis(200)).map(__ -> testRSocket)); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool(rSocketConnectorMock, source, new RoundRobinLoadbalanceStrategy()); + + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport))); + + StepVerifier.create(rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE)) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(2)); + + testRSocket.dispose(); + + Assertions.assertThatThrownBy( + () -> + rSocketPool + .select() + .fireAndForget(EmptyPayload.INSTANCE) + .block(Duration.ofSeconds(2))) + .isExactlyInstanceOf(IllegalStateException.class) + .hasMessage("Timeout on blocking read for 2000000000 NANOSECONDS"); + + Assertions.assertThat(counter.get()).isOne(); + } + + @Test + public void ensureContextIsPropagatedCorrectlyForRequestChannel() { + final AtomicInteger counter = new AtomicInteger(); + final ClientTransport mockTransport = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then( + im -> + Mono.delay(Duration.ofMillis(200)) + .map( + __ -> + new TestRSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher source) { + counter.incrementAndGet(); + return Flux.from(source); + } + }))); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool(rSocketConnectorMock, source, new RoundRobinLoadbalanceStrategy()); + + // check that context is propagated when there is no rsocket + StepVerifier.create( + rSocketPool + .select() + .requestChannel( + Flux.deferContextual( + cv -> { + if (cv.hasKey("test") && cv.get("test").equals("test")) { + return Flux.just(EmptyPayload.INSTANCE); + } else { + return Flux.error( + new IllegalStateException("Expected context to be propagated")); + } + })) + .contextWrite(Context.of("test", "test"))) + .expectSubscription() + .then( + () -> + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport)))) + .expectNextCount(1) + .expectComplete() + .verify(Duration.ofSeconds(2)); + + source.next(Collections.singletonList(LoadbalanceTarget.from("2", mockTransport))); + // check that context is propagated when there is an RSocket but it is unresolved + StepVerifier.create( + rSocketPool + .select() + .requestChannel( + Flux.deferContextual( + cv -> { + if (cv.hasKey("test") && cv.get("test").equals("test")) { + return Flux.just(EmptyPayload.INSTANCE); + } else { + return Flux.error( + new IllegalStateException("Expected context to be propagated")); + } + })) + .contextWrite(Context.of("test", "test"))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(Duration.ofSeconds(2)); + + // check that context is propagated when there is an RSocket and it is resolved + StepVerifier.create( + rSocketPool + .select() + .requestChannel( + Flux.deferContextual( + cv -> { + if (cv.hasKey("test") && cv.get("test").equals("test")) { + return Flux.just(EmptyPayload.INSTANCE); + } else { + return Flux.error( + new IllegalStateException("Expected context to be propagated")); + } + })) + .contextWrite(Context.of("test", "test"))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(Duration.ofSeconds(2)); + + Assertions.assertThat(counter.get()).isEqualTo(3); + } + + @Test + public void shouldNotifyOnCloseWhenAllTheActiveSubscribersAreClosed() { + final AtomicInteger counter = new AtomicInteger(); + final ClientTransport mockTransport = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + + Sinks.Empty onCloseSocket1 = Sinks.empty(); + Sinks.Empty onCloseSocket2 = Sinks.empty(); + + RSocket socket1 = + new RSocket() { + @Override + public Mono onClose() { + return onCloseSocket1.asMono(); + } + + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + }; + RSocket socket2 = + new RSocket() { + @Override + public Mono onClose() { + return onCloseSocket2.asMono(); + } + + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + }; + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then(im -> Mono.just(socket1)) + .then(im -> Mono.just(socket2)) + .then(im -> Mono.never().doOnCancel(() -> counter.incrementAndGet())); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool(rSocketConnectorMock, source, new RoundRobinLoadbalanceStrategy()); + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport), + LoadbalanceTarget.from("2", mockTransport), + LoadbalanceTarget.from("3", mockTransport))); + + StepVerifier.create(rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE)) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(2)); + + StepVerifier.create(rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE)) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(2)); + + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + + rSocketPool.dispose(); + + AssertSubscriber onCloseSubscriber = + rSocketPool.onClose().subscribeWith(AssertSubscriber.create()); + + onCloseSubscriber.assertNotTerminated(); + + onCloseSocket1.tryEmitEmpty(); + + onCloseSubscriber.assertNotTerminated(); + + onCloseSocket2.tryEmitEmpty(); + + onCloseSubscriber.assertTerminated().assertComplete(); + + Assertions.assertThat(counter.get()).isOne(); + } + + static class TestRSocket extends RSocketProxy { + + final Sinks.Empty sink = Sinks.empty(); + + public TestRSocket(RSocket rSocket) { + super(rSocket); + } + + @Override + public Mono onClose() { + return sink.asMono(); + } + + @Override + public void dispose() { + sink.tryEmitEmpty(); + } + + public RSocket source() { + return source; + } + } + + private static class WeightedRSocket extends BaseWeightedStats implements RSocket { + + private final AtomicInteger counter; + + public WeightedRSocket(AtomicInteger counter) { + this.counter = counter; + } + + @Override + public Mono fireAndForget(Payload payload) { + final long startTime = startRequest(); + counter.incrementAndGet(); + return Mono.empty() + .doFinally( + (__) -> { + final long stopTime = stopRequest(startTime); + record(stopTime - startTime); + }); + } + } + + static class RetrySubscriber implements CoreSubscriber { + + final Publisher source; + + private RetrySubscriber(Publisher source) { + this.source = source; + } + + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(Void unused) {} + + @Override + public void onError(Throwable t) { + source.subscribe(this); + } + + @Override + public void onComplete() {} + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategyTest.java b/rsocket-core/src/test/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategyTest.java new file mode 100644 index 000000000..e43068dbd --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategyTest.java @@ -0,0 +1,170 @@ +package io.rsocket.loadbalance; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.RaceTestConstants; +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.ClientTransport; +import io.rsocket.util.EmptyPayload; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.assertj.core.api.Assertions; +import org.assertj.core.data.Offset; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.test.publisher.TestPublisher; + +public class RoundRobinLoadbalanceStrategyTest { + + @BeforeEach + void setUp() { + Hooks.onErrorDropped((__) -> {}); + } + + @AfterAll + static void afterAll() { + Hooks.resetOnErrorDropped(); + } + + @Test + public void shouldDeliverValuesProportionally() { + final AtomicInteger counter1 = new AtomicInteger(); + final AtomicInteger counter2 = new AtomicInteger(); + final ClientTransport mockTransport = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then( + im -> + Mono.just( + new LoadbalanceTest.TestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter1.incrementAndGet(); + return Mono.empty(); + } + }))) + .then( + im -> + Mono.just( + new LoadbalanceTest.TestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter2.incrementAndGet(); + return Mono.empty(); + } + }))); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool(rSocketConnectorMock, source, new RoundRobinLoadbalanceStrategy()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport), + LoadbalanceTarget.from("2", mockTransport))); + + Assertions.assertThat(counter1.get()).isCloseTo(500, Offset.offset(1)); + Assertions.assertThat(counter2.get()).isCloseTo(500, Offset.offset(1)); + } + + @Test + public void shouldDeliverValuesToNewlyConnectedSockets() { + final AtomicInteger counter1 = new AtomicInteger(); + final AtomicInteger counter2 = new AtomicInteger(); + final ClientTransport mockTransport1 = Mockito.mock(ClientTransport.class); + final ClientTransport mockTransport2 = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then( + im -> + Mono.just( + new LoadbalanceTest.TestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + if (im.getArgument(0) == mockTransport1) { + counter1.incrementAndGet(); + } else { + counter2.incrementAndGet(); + } + return Mono.empty(); + } + }))); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool(rSocketConnectorMock, source, new RoundRobinLoadbalanceStrategy()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport1))); + + Assertions.assertThat(counter1.get()).isCloseTo(RaceTestConstants.REPEATS, Offset.offset(1)); + + source.next(Collections.emptyList()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport1))); + + Assertions.assertThat(counter1.get()) + .isCloseTo(RaceTestConstants.REPEATS * 2, Offset.offset(1)); + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport1), + LoadbalanceTarget.from("2", mockTransport2))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo(RaceTestConstants.REPEATS * 2 + RaceTestConstants.REPEATS / 2, Offset.offset(1)); + Assertions.assertThat(counter2.get()) + .isCloseTo(RaceTestConstants.REPEATS / 2, Offset.offset(1)); + + source.next(Collections.singletonList(LoadbalanceTarget.from("2", mockTransport1))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo(RaceTestConstants.REPEATS * 2 + RaceTestConstants.REPEATS / 2, Offset.offset(1)); + Assertions.assertThat(counter2.get()) + .isCloseTo(RaceTestConstants.REPEATS + RaceTestConstants.REPEATS / 2, Offset.offset(1)); + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport1), + LoadbalanceTarget.from("2", mockTransport2))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo(RaceTestConstants.REPEATS * 3, Offset.offset(1)); + Assertions.assertThat(counter2.get()) + .isCloseTo(RaceTestConstants.REPEATS * 2, Offset.offset(1)); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategyTest.java b/rsocket-core/src/test/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategyTest.java new file mode 100644 index 000000000..8cc254cbb --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategyTest.java @@ -0,0 +1,254 @@ +package io.rsocket.loadbalance; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.RaceTestConstants; +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.ClientTransport; +import io.rsocket.util.Clock; +import io.rsocket.util.EmptyPayload; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.assertj.core.api.Assertions; +import org.assertj.core.data.Offset; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.publisher.TestPublisher; + +public class WeightedLoadbalanceStrategyTest { + + @BeforeEach + void setUp() { + Hooks.onErrorDropped((__) -> {}); + } + + @AfterAll + static void afterAll() { + Hooks.resetOnErrorDropped(); + } + + @Test + public void allRequestsShouldGoToTheSocketWithHigherWeight() { + final AtomicInteger counter1 = new AtomicInteger(); + final AtomicInteger counter2 = new AtomicInteger(); + final ClientTransport mockTransport = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + final WeightedTestRSocket rSocket1 = + new WeightedTestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter1.incrementAndGet(); + return Mono.empty(); + } + }); + final WeightedTestRSocket rSocket2 = + new WeightedTestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter2.incrementAndGet(); + return Mono.empty(); + } + }); + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then(im -> Mono.just(rSocket1)) + .then(im -> Mono.just(rSocket2)); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool( + rSocketConnectorMock, + source, + WeightedLoadbalanceStrategy.builder() + .weightedStatsResolver(r -> r instanceof WeightedStats ? (WeightedStats) r : null) + .build()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport), + LoadbalanceTarget.from("2", mockTransport))); + + Assertions.assertThat(counter1.get()) + .describedAs("c1=" + counter1.get() + " c2=" + counter2.get()) + .isCloseTo( + RaceTestConstants.REPEATS, Offset.offset(Math.round(RaceTestConstants.REPEATS * 0.1f))); + Assertions.assertThat(counter2.get()) + .describedAs("c1=" + counter1.get() + " c2=" + counter2.get()) + .isCloseTo(0, Offset.offset(Math.round(RaceTestConstants.REPEATS * 0.1f))); + } + + @Test + public void shouldDeliverValuesToTheSocketWithTheHighestCalculatedWeight() { + final AtomicInteger counter1 = new AtomicInteger(); + final AtomicInteger counter2 = new AtomicInteger(); + final ClientTransport mockTransport1 = Mockito.mock(ClientTransport.class); + final ClientTransport mockTransport2 = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + final WeightedTestRSocket rSocket1 = + new WeightedTestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter1.incrementAndGet(); + return Mono.empty(); + } + }); + final WeightedTestRSocket rSocket2 = + new WeightedTestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter1.incrementAndGet(); + return Mono.empty(); + } + }); + final WeightedTestRSocket rSocket3 = + new WeightedTestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter2.incrementAndGet(); + return Mono.empty(); + } + }); + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then(im -> Mono.just(rSocket1)) + .then(im -> Mono.just(rSocket2)) + .then(im -> Mono.just(rSocket3)); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool( + rSocketConnectorMock, + source, + WeightedLoadbalanceStrategy.builder() + .weightedStatsResolver(r -> r instanceof WeightedStats ? (WeightedStats) r : null) + .build()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport1))); + + Assertions.assertThat(counter1.get()).isCloseTo(RaceTestConstants.REPEATS, Offset.offset(1)); + + source.next(Collections.emptyList()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + rSocket1.updateAvailability(0.0); + + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport1))); + + Assertions.assertThat(counter1.get()) + .isCloseTo(RaceTestConstants.REPEATS * 2, Offset.offset(1)); + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport1), + LoadbalanceTarget.from("2", mockTransport2))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + final RSocket rSocket = rSocketPool.select(); + rSocket.fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo( + RaceTestConstants.REPEATS * 3, + Offset.offset(Math.round(RaceTestConstants.REPEATS * 3 * 0.1f))); + Assertions.assertThat(counter2.get()) + .isCloseTo(0, Offset.offset(Math.round(RaceTestConstants.REPEATS * 3 * 0.1f))); + + rSocket2.updateAvailability(0.0); + + source.next(Collections.singletonList(LoadbalanceTarget.from("2", mockTransport1))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo( + RaceTestConstants.REPEATS * 3, + Offset.offset(Math.round(RaceTestConstants.REPEATS * 4 * 0.1f))); + Assertions.assertThat(counter2.get()) + .isCloseTo( + RaceTestConstants.REPEATS, + Offset.offset(Math.round(RaceTestConstants.REPEATS * 4 * 0.1f))); + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport1), + LoadbalanceTarget.from("2", mockTransport2))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + final RSocket rSocket = rSocketPool.select(); + rSocket.fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo( + RaceTestConstants.REPEATS * 3, + Offset.offset(Math.round(RaceTestConstants.REPEATS * 5 * 0.1f))); + Assertions.assertThat(counter2.get()) + .isCloseTo( + RaceTestConstants.REPEATS * 2, + Offset.offset(Math.round(RaceTestConstants.REPEATS * 5 * 0.1f))); + } + + static class WeightedTestRSocket extends BaseWeightedStats implements RSocket { + + final Sinks.Empty sink = Sinks.empty(); + + final RSocket rSocket; + + public WeightedTestRSocket(RSocket rSocket) { + this.rSocket = rSocket; + } + + @Override + public Mono fireAndForget(Payload payload) { + startRequest(); + final long startTime = Clock.now(); + return this.rSocket + .fireAndForget(payload) + .doFinally( + __ -> { + stopRequest(startTime); + record(Clock.now() - startTime); + updateAvailability(1.0); + }); + } + + @Override + public Mono onClose() { + return sink.asMono(); + } + + @Override + public void dispose() { + sink.tryEmitEmpty(); + } + + public RSocket source() { + return rSocket; + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/AuthMetadataCodecTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/AuthMetadataCodecTest.java new file mode 100644 index 000000000..58ab30021 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/AuthMetadataCodecTest.java @@ -0,0 +1,474 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +public class AuthMetadataCodecTest { + + public static final int AUTH_TYPE_ID_LENGTH = 1; + public static final int USER_NAME_BYTES_LENGTH = 2; + public static final String TEST_BEARER_TOKEN = + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJpYXQxIjoxNTE2MjM5MDIyLCJpYXQyIjoxNTE2MjM5MDIyLCJpYXQzIjoxNTE2MjM5MDIyLCJpYXQ0IjoxNTE2MjM5MDIyfQ.ljYuH-GNyyhhLcx-rHMchRkGbNsR2_4aSxo8XjrYrSM"; + + @Test + void shouldCorrectlyEncodeData() { + String username = "test"; + String password = "tset1234"; + + int usernameLength = username.length(); + int passwordLength = password.length(); + + ByteBuf byteBuf = + AuthMetadataCodec.encodeSimpleMetadata( + ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray()); + + byteBuf.markReaderIndex(); + checkSimpleAuthMetadataEncoding( + username, password, usernameLength, passwordLength, byteBuf.retain()); + byteBuf.resetReaderIndex(); + checkSimpleAuthMetadataEncodingUsingDecoders( + username, password, usernameLength, passwordLength, byteBuf); + } + + @Test + void shouldCorrectlyEncodeData1() { + String username = "𠜎𠜱𠝹𠱓𠱸𠲖𠳏𠳕𠴕𠵼𠵿𠸎"; + String password = "tset1234"; + + int usernameLength = username.getBytes(CharsetUtil.UTF_8).length; + int passwordLength = password.length(); + + ByteBuf byteBuf = + AuthMetadataCodec.encodeSimpleMetadata( + ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray()); + + byteBuf.markReaderIndex(); + checkSimpleAuthMetadataEncoding( + username, password, usernameLength, passwordLength, byteBuf.retain()); + byteBuf.resetReaderIndex(); + checkSimpleAuthMetadataEncodingUsingDecoders( + username, password, usernameLength, passwordLength, byteBuf); + } + + @Test + void shouldCorrectlyEncodeData2() { + String username = "𠜎𠜱𠝹𠱓𠱸𠲖𠳏𠳕𠴕𠵼𠵿𠸎1234567#4? "; + String password = "tset1234"; + + int usernameLength = username.getBytes(CharsetUtil.UTF_8).length; + int passwordLength = password.length(); + + ByteBuf byteBuf = + AuthMetadataCodec.encodeSimpleMetadata( + ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray()); + + byteBuf.markReaderIndex(); + checkSimpleAuthMetadataEncoding( + username, password, usernameLength, passwordLength, byteBuf.retain()); + byteBuf.resetReaderIndex(); + checkSimpleAuthMetadataEncodingUsingDecoders( + username, password, usernameLength, passwordLength, byteBuf); + } + + private static void checkSimpleAuthMetadataEncoding( + String username, String password, int usernameLength, int passwordLength, ByteBuf byteBuf) { + Assertions.assertThat(byteBuf.capacity()) + .isEqualTo(AUTH_TYPE_ID_LENGTH + USER_NAME_BYTES_LENGTH + usernameLength + passwordLength); + + Assertions.assertThat(byteBuf.readUnsignedByte() & ~0x80) + .isEqualTo(WellKnownAuthType.SIMPLE.getIdentifier()); + Assertions.assertThat(byteBuf.readUnsignedShort()).isEqualTo((short) usernameLength); + + Assertions.assertThat(byteBuf.readCharSequence(usernameLength, CharsetUtil.UTF_8)) + .isEqualTo(username); + Assertions.assertThat(byteBuf.readCharSequence(passwordLength, CharsetUtil.UTF_8)) + .isEqualTo(password); + + ReferenceCountUtil.release(byteBuf); + } + + private static void checkSimpleAuthMetadataEncodingUsingDecoders( + String username, String password, int usernameLength, int passwordLength, ByteBuf byteBuf) { + Assertions.assertThat(byteBuf.capacity()) + .isEqualTo(AUTH_TYPE_ID_LENGTH + USER_NAME_BYTES_LENGTH + usernameLength + passwordLength); + + Assertions.assertThat(AuthMetadataCodec.readWellKnownAuthType(byteBuf)) + .isEqualTo(WellKnownAuthType.SIMPLE); + byteBuf.markReaderIndex(); + Assertions.assertThat(AuthMetadataCodec.readUsername(byteBuf).toString(CharsetUtil.UTF_8)) + .isEqualTo(username); + Assertions.assertThat(AuthMetadataCodec.readPassword(byteBuf).toString(CharsetUtil.UTF_8)) + .isEqualTo(password); + byteBuf.resetReaderIndex(); + + Assertions.assertThat(new String(AuthMetadataCodec.readUsernameAsCharArray(byteBuf))) + .isEqualTo(username); + Assertions.assertThat(new String(AuthMetadataCodec.readPasswordAsCharArray(byteBuf))) + .isEqualTo(password); + + ReferenceCountUtil.release(byteBuf); + } + + @Test + void shouldThrowExceptionIfUsernameLengthExitsAllowedBounds() { + StringBuilder usernameBuilder = new StringBuilder(); + String usernamePart = + "𠜎𠜱𠝹𠱓𠱸𠲖𠳏𠳕𠴕𠵼𠵿𠸎𠸏𠹷𠺝𠺢𠻗𠻹𠻺𠼭𠼮𠽌𠾴𠾼𠿪𡁜𡁯𡁵𡁶𡁻𡃁𡃉𡇙𢃇𢞵𢫕𢭃𢯊𢱑𢱕𢳂𢴈𢵌𢵧𢺳𣲷𤓓𤶸𤷪𥄫𦉘𦟌𦧲𦧺𧨾𨅝𨈇𨋢𨳊𨳍𨳒𩶘𠜎𠜱𠝹"; + for (int i = 0; i < 65535 / usernamePart.length(); i++) { + usernameBuilder.append(usernamePart); + } + String password = "tset1234"; + + Assertions.assertThatThrownBy( + () -> + AuthMetadataCodec.encodeSimpleMetadata( + ByteBufAllocator.DEFAULT, + usernameBuilder.toString().toCharArray(), + password.toCharArray())) + .hasMessage( + "Username should be shorter than or equal to 65535 bytes length in UTF-8 encoding"); + } + + @Test + void shouldEncodeBearerMetadata() { + String testToken = TEST_BEARER_TOKEN; + + ByteBuf byteBuf = + AuthMetadataCodec.encodeBearerMetadata(ByteBufAllocator.DEFAULT, testToken.toCharArray()); + + byteBuf.markReaderIndex(); + checkBearerAuthMetadataEncoding(testToken, byteBuf); + byteBuf.resetReaderIndex(); + checkBearerAuthMetadataEncodingUsingDecoders(testToken, byteBuf); + } + + private static void checkBearerAuthMetadataEncoding(String testToken, ByteBuf byteBuf) { + Assertions.assertThat(byteBuf.capacity()) + .isEqualTo(testToken.getBytes(CharsetUtil.UTF_8).length + AUTH_TYPE_ID_LENGTH); + Assertions.assertThat( + byteBuf.readUnsignedByte() & ~AuthMetadataCodec.STREAM_METADATA_KNOWN_MASK) + .isEqualTo(WellKnownAuthType.BEARER.getIdentifier()); + Assertions.assertThat(byteBuf.readSlice(byteBuf.capacity() - 1).toString(CharsetUtil.UTF_8)) + .isEqualTo(testToken); + } + + private static void checkBearerAuthMetadataEncodingUsingDecoders( + String testToken, ByteBuf byteBuf) { + Assertions.assertThat(byteBuf.capacity()) + .isEqualTo(testToken.getBytes(CharsetUtil.UTF_8).length + AUTH_TYPE_ID_LENGTH); + Assertions.assertThat(AuthMetadataCodec.isWellKnownAuthType(byteBuf)).isTrue(); + Assertions.assertThat(AuthMetadataCodec.readWellKnownAuthType(byteBuf)) + .isEqualTo(WellKnownAuthType.BEARER); + byteBuf.markReaderIndex(); + Assertions.assertThat(new String(AuthMetadataCodec.readBearerTokenAsCharArray(byteBuf))) + .isEqualTo(testToken); + byteBuf.resetReaderIndex(); + Assertions.assertThat( + AuthMetadataCodec.readPayload(byteBuf).toString(CharsetUtil.UTF_8).toString()) + .isEqualTo(testToken); + } + + @Test + void shouldEncodeCustomAuth() { + String payloadAsAText = "testsecuritybuffer"; + ByteBuf testSecurityPayload = + Unpooled.wrappedBuffer(payloadAsAText.getBytes(CharsetUtil.UTF_8)); + + String customAuthType = "myownauthtype"; + ByteBuf buffer = + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload); + + checkCustomAuthMetadataEncoding(testSecurityPayload, customAuthType, buffer); + } + + private static void checkCustomAuthMetadataEncoding( + ByteBuf testSecurityPayload, String customAuthType, ByteBuf buffer) { + Assertions.assertThat(buffer.capacity()) + .isEqualTo(1 + customAuthType.length() + testSecurityPayload.capacity()); + Assertions.assertThat(buffer.readUnsignedByte()) + .isEqualTo((short) (customAuthType.length() - 1)); + Assertions.assertThat( + buffer.readCharSequence(customAuthType.length(), CharsetUtil.US_ASCII).toString()) + .isEqualTo(customAuthType); + Assertions.assertThat(buffer.readSlice(testSecurityPayload.capacity())) + .isEqualTo(testSecurityPayload); + + ReferenceCountUtil.release(buffer); + } + + @Test + void shouldThrowOnNonASCIIChars() { + ByteBuf testSecurityPayload = ByteBufAllocator.DEFAULT.buffer(); + String customAuthType = "1234567#4? 𠜎𠜱𠝹𠱓𠱸𠲖𠳏𠳕𠴕𠵼𠵿𠸎"; + + Assertions.assertThatThrownBy( + () -> + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload)) + .hasMessage("custom auth type must be US_ASCII characters only"); + } + + @Test + void shouldThrowOnOutOfAllowedSizeType() { + ByteBuf testSecurityPayload = ByteBufAllocator.DEFAULT.buffer(); + // 130 chars + String customAuthType = + "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"; + + Assertions.assertThatThrownBy( + () -> + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload)) + .hasMessage( + "custom auth type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + @Test + void shouldThrowOnOutOfAllowedSizeType1() { + ByteBuf testSecurityPayload = ByteBufAllocator.DEFAULT.buffer(); + String customAuthType = ""; + + Assertions.assertThatThrownBy( + () -> + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload)) + .hasMessage( + "custom auth type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + @Test + void shouldEncodeUsingWellKnownAuthType() { + ByteBuf byteBuf = + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, + WellKnownAuthType.SIMPLE, + ByteBufAllocator.DEFAULT.buffer().writeShort(1).writeByte('u').writeByte('p')); + + checkSimpleAuthMetadataEncoding("u", "p", 1, 1, byteBuf); + } + + @Test + void shouldEncodeUsingWellKnownAuthType1() { + ByteBuf byteBuf = + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, + WellKnownAuthType.SIMPLE, + ByteBufAllocator.DEFAULT.buffer().writeShort(1).writeByte('u').writeByte('p')); + + checkSimpleAuthMetadataEncoding("u", "p", 1, 1, byteBuf); + } + + @Test + void shouldEncodeUsingWellKnownAuthType2() { + ByteBuf byteBuf = + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, + WellKnownAuthType.BEARER, + Unpooled.copiedBuffer(TEST_BEARER_TOKEN, CharsetUtil.UTF_8)); + + byteBuf.markReaderIndex(); + checkBearerAuthMetadataEncoding(TEST_BEARER_TOKEN, byteBuf); + byteBuf.resetReaderIndex(); + checkBearerAuthMetadataEncodingUsingDecoders(TEST_BEARER_TOKEN, byteBuf); + } + + @Test + void shouldThrowIfWellKnownAuthTypeIsUnsupportedOrUnknown() { + ByteBuf buffer = ByteBufAllocator.DEFAULT.buffer(); + + Assertions.assertThatThrownBy( + () -> + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, WellKnownAuthType.UNPARSEABLE_AUTH_TYPE, buffer)) + .hasMessage("only allowed AuthType should be used"); + + Assertions.assertThatThrownBy( + () -> + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, WellKnownAuthType.UNPARSEABLE_AUTH_TYPE, buffer)) + .hasMessage("only allowed AuthType should be used"); + + buffer.release(); + } + + @Test + void shouldCompressMetadata() { + ByteBuf byteBuf = + AuthMetadataCodec.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, + "simple", + ByteBufAllocator.DEFAULT.buffer().writeShort(1).writeByte('u').writeByte('p')); + + checkSimpleAuthMetadataEncoding("u", "p", 1, 1, byteBuf); + } + + @Test + void shouldCompressMetadata1() { + ByteBuf byteBuf = + AuthMetadataCodec.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, + "bearer", + Unpooled.copiedBuffer(TEST_BEARER_TOKEN, CharsetUtil.UTF_8)); + + byteBuf.markReaderIndex(); + checkBearerAuthMetadataEncoding(TEST_BEARER_TOKEN, byteBuf); + byteBuf.resetReaderIndex(); + checkBearerAuthMetadataEncodingUsingDecoders(TEST_BEARER_TOKEN, byteBuf); + } + + @Test + void shouldNotCompressMetadata() { + ByteBuf testMetadataPayload = + Unpooled.wrappedBuffer(TEST_BEARER_TOKEN.getBytes(CharsetUtil.UTF_8)); + String customAuthType = "testauthtype"; + ByteBuf byteBuf = + AuthMetadataCodec.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, customAuthType, testMetadataPayload); + + checkCustomAuthMetadataEncoding(testMetadataPayload, customAuthType, byteBuf); + } + + @Test + void shouldConfirmWellKnownAuthType() { + ByteBuf metadata = + AuthMetadataCodec.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, "simple", Unpooled.EMPTY_BUFFER); + + int initialReaderIndex = metadata.readerIndex(); + + Assertions.assertThat(AuthMetadataCodec.isWellKnownAuthType(metadata)).isTrue(); + Assertions.assertThat(metadata.readerIndex()).isEqualTo(initialReaderIndex); + + ReferenceCountUtil.release(metadata); + } + + @Test + void shouldConfirmGivenMetadataIsNotAWellKnownAuthType() { + ByteBuf metadata = + AuthMetadataCodec.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, "simple/afafgafadf", Unpooled.EMPTY_BUFFER); + + int initialReaderIndex = metadata.readerIndex(); + + Assertions.assertThat(AuthMetadataCodec.isWellKnownAuthType(metadata)).isFalse(); + Assertions.assertThat(metadata.readerIndex()).isEqualTo(initialReaderIndex); + + ReferenceCountUtil.release(metadata); + } + + @Test + void shouldReadSimpleWellKnownAuthType() { + ByteBuf metadata = + AuthMetadataCodec.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, "simple", Unpooled.EMPTY_BUFFER); + WellKnownAuthType expectedType = WellKnownAuthType.SIMPLE; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldReadSimpleWellKnownAuthType1() { + ByteBuf metadata = + AuthMetadataCodec.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, "bearer", Unpooled.EMPTY_BUFFER); + WellKnownAuthType expectedType = WellKnownAuthType.BEARER; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldReadSimpleWellKnownAuthType2() { + ByteBuf metadata = + ByteBufAllocator.DEFAULT + .buffer() + .writeByte(3 | AuthMetadataCodec.STREAM_METADATA_KNOWN_MASK); + WellKnownAuthType expectedType = WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldNotReadSimpleWellKnownAuthTypeIfEncodedLength() { + ByteBuf metadata = ByteBufAllocator.DEFAULT.buffer().writeByte(3); + WellKnownAuthType expectedType = WellKnownAuthType.UNPARSEABLE_AUTH_TYPE; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldNotReadSimpleWellKnownAuthTypeIfEncodedLength1() { + ByteBuf metadata = + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, "testmetadataauthtype", Unpooled.EMPTY_BUFFER); + WellKnownAuthType expectedType = WellKnownAuthType.UNPARSEABLE_AUTH_TYPE; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldThrowExceptionIsNotEnoughReadableBytes() { + Assertions.assertThatThrownBy( + () -> AuthMetadataCodec.readWellKnownAuthType(Unpooled.EMPTY_BUFFER)) + .hasMessage("Unable to decode Well Know Auth type. Not enough readable bytes"); + } + + private static void checkDecodeWellKnowAuthTypeCorrectly( + ByteBuf metadata, WellKnownAuthType expectedType) { + int initialReaderIndex = metadata.readerIndex(); + + WellKnownAuthType wellKnownAuthType = AuthMetadataCodec.readWellKnownAuthType(metadata); + + Assertions.assertThat(wellKnownAuthType).isEqualTo(expectedType); + Assertions.assertThat(metadata.readerIndex()) + .isNotEqualTo(initialReaderIndex) + .isEqualTo(initialReaderIndex + 1); + + ReferenceCountUtil.release(metadata); + } + + @Test + void shouldReadCustomEncodedAuthType() { + String testAuthType = "TestAuthType"; + ByteBuf byteBuf = + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, testAuthType, Unpooled.EMPTY_BUFFER); + checkDecodeCustomAuthTypeCorrectly(testAuthType, byteBuf); + } + + @Test + void shouldThrowExceptionOnEmptyMetadata() { + Assertions.assertThatThrownBy(() -> AuthMetadataCodec.readCustomAuthType(Unpooled.EMPTY_BUFFER)) + .hasMessage("Unable to decode custom Auth type. Not enough readable bytes"); + } + + @Test + void shouldThrowExceptionOnMalformedMetadata_wellknowninstead() { + Assertions.assertThatThrownBy( + () -> + AuthMetadataCodec.readCustomAuthType( + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, + WellKnownAuthType.BEARER, + Unpooled.copiedBuffer(new byte[] {'a', 'b'})))) + .hasMessage("Unable to decode custom Auth type. Incorrect auth type length"); + } + + @Test + void shouldThrowExceptionOnMalformedMetadata_length() { + Assertions.assertThatThrownBy( + () -> + AuthMetadataCodec.readCustomAuthType( + ByteBufAllocator.DEFAULT.buffer().writeByte(127).writeChar('a').writeChar('b'))) + .hasMessage("Unable to decode custom Auth type. Malformed length or auth type string"); + } + + private static void checkDecodeCustomAuthTypeCorrectly(String testAuthType, ByteBuf byteBuf) { + int initialReaderIndex = byteBuf.readerIndex(); + + Assertions.assertThat(AuthMetadataCodec.readCustomAuthType(byteBuf).toString()) + .isEqualTo(testAuthType); + Assertions.assertThat(byteBuf.readerIndex()) + .isEqualTo(initialReaderIndex + testAuthType.length() + 1); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataCodecTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataCodecTest.java new file mode 100644 index 000000000..a4e8fb2d8 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataCodecTest.java @@ -0,0 +1,558 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import static io.rsocket.metadata.CompositeMetadataCodec.decodeMimeAndContentBuffersSlices; +import static io.rsocket.metadata.CompositeMetadataCodec.decodeMimeIdFromMimeBuffer; +import static io.rsocket.metadata.CompositeMetadataCodec.decodeMimeTypeFromMimeBuffer; +import static org.assertj.core.api.Assertions.*; + +import io.netty.buffer.*; +import io.netty.util.CharsetUtil; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.test.util.ByteBufUtils; +import io.rsocket.util.NumberUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +class CompositeMetadataCodecTest { + + final LeaksTrackingByteBufAllocator testAllocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + @AfterEach + void tearDownAndCheckForLeaks() { + testAllocator.assertHasNoLeaks(); + } + + static String byteToBitsString(byte b) { + return String.format("%8s", Integer.toBinaryString(b & 0xFF)).replace(' ', '0'); + } + + static String toHeaderBits(ByteBuf encoded) { + encoded.markReaderIndex(); + byte headerByte = encoded.readByte(); + String byteAsString = byteToBitsString(headerByte); + encoded.resetReaderIndex(); + return byteAsString; + } + // ==== + + @Test + void customMimeHeaderLatin1_encodingFails() { + String mimeNotAscii = "mime/typé"; + + assertThatIllegalArgumentException() + .isThrownBy( + () -> CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeNotAscii, 0)) + .withMessage("custom mime type must be US_ASCII characters only"); + } + + @Test + void customMimeHeaderLength0_encodingFails() { + assertThatIllegalArgumentException() + .isThrownBy(() -> CompositeMetadataCodec.encodeMetadataHeader(testAllocator, "", 0)) + .withMessage( + "custom mime type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + @Test + void customMimeHeaderLength127() { + StringBuilder builder = new StringBuilder(127); + for (int i = 0; i < 127; i++) { + builder.append('a'); + } + String mimeString = builder.toString(); + ByteBuf encoded = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeString, 0); + + // remember actual length = encoded length + 1 + assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("01111110"); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isGreaterThan(1); + + assertThat((int) header.readByte()) + .as("mime length") + .isEqualTo(127 - 1); // encoded as actual length - 1 + + assertThat(header.readCharSequence(127, CharsetUtil.US_ASCII)) + .as("mime string") + .hasToString(mimeString); + + header.resetReaderIndex(); + assertThat(CompositeMetadataCodec.decodeMimeTypeFromMimeBuffer(header)) + .as("decoded mime string") + .hasToString(mimeString); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); + } + + @Test + void customMimeHeaderLength128() { + StringBuilder builder = new StringBuilder(128); + for (int i = 0; i < 128; i++) { + builder.append('a'); + } + String mimeString = builder.toString(); + ByteBuf encoded = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeString, 0); + + // remember actual length = encoded length + 1 + assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("01111111"); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isGreaterThan(1); + + assertThat((int) header.readByte()) + .as("mime length") + .isEqualTo(128 - 1); // encoded as actual length - 1 + + assertThat(header.readCharSequence(128, CharsetUtil.US_ASCII)) + .as("mime string") + .hasToString(mimeString); + + header.resetReaderIndex(); + assertThat(CompositeMetadataCodec.decodeMimeTypeFromMimeBuffer(header)) + .as("decoded mime string") + .hasToString(mimeString); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); + } + + @Test + void customMimeHeaderLength129_encodingFails() { + StringBuilder builder = new StringBuilder(129); + for (int i = 0; i < 129; i++) { + builder.append('a'); + } + + assertThatIllegalArgumentException() + .isThrownBy( + () -> CompositeMetadataCodec.encodeMetadataHeader(testAllocator, builder.toString(), 0)) + .withMessage( + "custom mime type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + @Test + void customMimeHeaderLengthOne() { + String mimeString = "w"; + ByteBuf encoded = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeString, 0); + + // remember actual length = encoded length + 1 + assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("00000000"); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isGreaterThan(1); + + assertThat((int) header.readByte()).as("mime length").isZero(); // encoded as actual length - 1 + + assertThat(header.readCharSequence(1, CharsetUtil.US_ASCII)) + .as("mime string") + .hasToString(mimeString); + + header.resetReaderIndex(); + assertThat(CompositeMetadataCodec.decodeMimeTypeFromMimeBuffer(header)) + .as("decoded mime string") + .hasToString(mimeString); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); + } + + @Test + void customMimeHeaderLengthTwo() { + String mimeString = "ww"; + ByteBuf encoded = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeString, 0); + + // remember actual length = encoded length + 1 + assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("00000001"); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isGreaterThan(1); + + assertThat((int) header.readByte()) + .as("mime length") + .isEqualTo(2 - 1); // encoded as actual length - 1 + + assertThat(header.readCharSequence(2, CharsetUtil.US_ASCII)) + .as("mime string") + .hasToString(mimeString); + + header.resetReaderIndex(); + assertThat(CompositeMetadataCodec.decodeMimeTypeFromMimeBuffer(header)) + .as("decoded mime string") + .hasToString(mimeString); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); + } + + @Test + void customMimeHeaderUtf8_encodingFails() { + String mimeNotAscii = + "mime/tyࠒe"; // this is the SAMARITAN LETTER QUF u+0812 represented on 3 bytes + assertThatIllegalArgumentException() + .isThrownBy( + () -> CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeNotAscii, 0)) + .withMessage("custom mime type must be US_ASCII characters only"); + } + + @Test + void decodeEntryAtEndOfBuffer() { + ByteBuf fakeEntry = Unpooled.buffer(); + + assertThatIllegalArgumentException() + .isThrownBy(() -> decodeMimeAndContentBuffersSlices(fakeEntry, 0, false)); + } + + @Test + void decodeEntryHasNoContentLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(0); + fakeEntry.writeCharSequence("w", CharsetUtil.US_ASCII); + + assertThatIllegalStateException() + .isThrownBy(() -> decodeMimeAndContentBuffersSlices(fakeEntry, 0, false)); + } + + @Test + void decodeEntryTooShortForContentLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(1); + fakeEntry.writeCharSequence("w", CharsetUtil.US_ASCII); + NumberUtils.encodeUnsignedMedium(fakeEntry, 456); + fakeEntry.writeChar('w'); + + assertThatIllegalStateException() + .isThrownBy(() -> decodeMimeAndContentBuffersSlices(fakeEntry, 0, false)); + } + + @Test + void decodeEntryTooShortForMimeLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(120); + + assertThatIllegalStateException() + .isThrownBy(() -> decodeMimeAndContentBuffersSlices(fakeEntry, 0, false)); + } + + @Test + void decodeIdMinusTwoWhenMoreThanOneByte() { + ByteBuf fakeIdBuffer = Unpooled.buffer(2); + fakeIdBuffer.writeInt(200); + + assertThat(decodeMimeIdFromMimeBuffer(fakeIdBuffer)) + .isEqualTo((WellKnownMimeType.UNPARSEABLE_MIME_TYPE.getIdentifier())); + } + + @Test + void decodeIdMinusTwoWhenZeroByte() { + ByteBuf fakeIdBuffer = Unpooled.buffer(0); + + assertThat(decodeMimeIdFromMimeBuffer(fakeIdBuffer)) + .isEqualTo((WellKnownMimeType.UNPARSEABLE_MIME_TYPE.getIdentifier())); + } + + @Test + void decodeStringNullIfLengthOne() { + ByteBuf fakeTypeBuffer = Unpooled.buffer(2); + fakeTypeBuffer.writeByte(1); + + assertThatIllegalStateException() + .isThrownBy(() -> decodeMimeTypeFromMimeBuffer(fakeTypeBuffer)); + } + + @Test + void decodeStringNullIfLengthZero() { + ByteBuf fakeTypeBuffer = Unpooled.buffer(2); + + assertThatIllegalStateException() + .isThrownBy(() -> decodeMimeTypeFromMimeBuffer(fakeTypeBuffer)); + } + + @Test + void decodeTypeSkipsFirstByte() { + ByteBuf fakeTypeBuffer = Unpooled.buffer(2); + fakeTypeBuffer.writeByte(128); + fakeTypeBuffer.writeCharSequence("example", CharsetUtil.US_ASCII); + + assertThat(decodeMimeTypeFromMimeBuffer(fakeTypeBuffer)).hasToString("example"); + } + + @Test + void encodeMetadataCustomTypeDelegates() { + ByteBuf expected = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, "foo", 2); + + CompositeByteBuf test = testAllocator.compositeBuffer(); + + CompositeMetadataCodec.encodeAndAddMetadata( + test, testAllocator, "foo", ByteBufUtils.getRandomByteBuf(2)); + + assertThat((Iterable) test).hasSize(2).first().isEqualTo(expected); + test.release(); + expected.release(); + } + + @Test + void encodeMetadataKnownTypeDelegates() { + ByteBuf expected = + CompositeMetadataCodec.encodeMetadataHeader( + testAllocator, WellKnownMimeType.APPLICATION_OCTET_STREAM.getIdentifier(), 2); + + CompositeByteBuf test = testAllocator.compositeBuffer(); + + CompositeMetadataCodec.encodeAndAddMetadata( + test, + testAllocator, + WellKnownMimeType.APPLICATION_OCTET_STREAM, + ByteBufUtils.getRandomByteBuf(2)); + + assertThat((Iterable) test).hasSize(2).first().isEqualTo(expected); + test.release(); + expected.release(); + } + + @Test + void encodeMetadataReservedTypeDelegates() { + ByteBuf expected = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, (byte) 120, 2); + + CompositeByteBuf test = testAllocator.compositeBuffer(); + + CompositeMetadataCodec.encodeAndAddMetadata( + test, testAllocator, (byte) 120, ByteBufUtils.getRandomByteBuf(2)); + + assertThat((Iterable) test).hasSize(2).first().isEqualTo(expected); + test.release(); + expected.release(); + } + + @Test + void encodeTryCompressWithCompressableType() { + ByteBuf metadata = ByteBufUtils.getRandomByteBuf(2); + CompositeByteBuf target = testAllocator.compositeBuffer(); + + CompositeMetadataCodec.encodeAndAddMetadataWithCompression( + target, testAllocator, WellKnownMimeType.APPLICATION_AVRO.getString(), metadata); + + assertThat(target.readableBytes()).as("readableBytes 1 + 3 + 2").isEqualTo(6); + target.release(); + } + + @Test + void encodeTryCompressWithCustomType() { + ByteBuf metadata = ByteBufUtils.getRandomByteBuf(2); + CompositeByteBuf target = testAllocator.compositeBuffer(); + + CompositeMetadataCodec.encodeAndAddMetadataWithCompression( + target, testAllocator, "custom/example", metadata); + + assertThat(target.readableBytes()).as("readableBytes 1 + 14 + 3 + 2").isEqualTo(20); + target.release(); + } + + @Test + void hasEntry() { + WellKnownMimeType mime = WellKnownMimeType.APPLICATION_AVRO; + + CompositeByteBuf buffer = + testAllocator + .compositeBuffer() + .addComponent( + true, + CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mime.getIdentifier(), 0)) + .addComponent( + true, + CompositeMetadataCodec.encodeMetadataHeader( + testAllocator, mime.getIdentifier(), 0)); + + assertThat(CompositeMetadataCodec.hasEntry(buffer, 0)).isTrue(); + assertThat(CompositeMetadataCodec.hasEntry(buffer, 4)).isTrue(); + assertThat(CompositeMetadataCodec.hasEntry(buffer, 8)).isFalse(); + buffer.release(); + } + + @Test + void isWellKnownMimeType() { + ByteBuf wellKnown = Unpooled.buffer().writeByte(0); + assertThat(CompositeMetadataCodec.isWellKnownMimeType(wellKnown)).isTrue(); + + ByteBuf explicit = Unpooled.buffer().writeByte(2).writeChar('a'); + assertThat(CompositeMetadataCodec.isWellKnownMimeType(explicit)).isFalse(); + } + + @Test + void knownMimeHeader120_reserved() { + byte mime = (byte) 120; + ByteBuf encoded = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mime, 0); + + assertThat(mime) + .as("smoke test RESERVED_120 unsigned 7 bits representation") + .isEqualTo((byte) 0b01111000); + + assertThat(toHeaderBits(encoded)).startsWith("1").isEqualTo("11111000"); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isOne(); + + assertThat(byteToBitsString(header.readByte())) + .as("header bit representation") + .isEqualTo("11111000"); + + header.resetReaderIndex(); + assertThat(decodeMimeIdFromMimeBuffer(header)).as("decoded mime id").isEqualTo(mime); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); + } + + @Test + void knownMimeHeader127_compositeMetadata() { + WellKnownMimeType mime = WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA; + assertThat(mime.getIdentifier()) + .as("smoke test COMPOSITE unsigned 7 bits representation") + .isEqualTo((byte) 127) + .isEqualTo((byte) 0b01111111); + ByteBuf encoded = + CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mime.getIdentifier(), 0); + + assertThat(toHeaderBits(encoded)) + .startsWith("1") + .isEqualTo("11111111") + .isEqualTo(byteToBitsString(mime.getIdentifier()).replaceFirst("0", "1")); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isOne(); + + assertThat(byteToBitsString(header.readByte())) + .as("header bit representation") + .isEqualTo("11111111"); + + header.resetReaderIndex(); + assertThat(decodeMimeIdFromMimeBuffer(header)) + .as("decoded mime id") + .isEqualTo(mime.getIdentifier()); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); + } + + @Test + void knownMimeHeaderZero_avro() { + WellKnownMimeType mime = WellKnownMimeType.APPLICATION_AVRO; + assertThat(mime.getIdentifier()) + .as("smoke test AVRO unsigned 7 bits representation") + .isEqualTo((byte) 0) + .isEqualTo((byte) 0b00000000); + ByteBuf encoded = + CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mime.getIdentifier(), 0); + + assertThat(toHeaderBits(encoded)) + .startsWith("1") + .isEqualTo("10000000") + .isEqualTo(byteToBitsString(mime.getIdentifier()).replaceFirst("0", "1")); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isOne(); + + assertThat(byteToBitsString(header.readByte())) + .as("header bit representation") + .isEqualTo("10000000"); + + header.resetReaderIndex(); + assertThat(decodeMimeIdFromMimeBuffer(header)) + .as("decoded mime id") + .isEqualTo(mime.getIdentifier()); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); + } + + @Test + void encodeCustomHeaderAsciiCheckSkipsFirstByte() { + final ByteBuf badBuf = Unpooled.copiedBuffer("é00000000000", CharsetUtil.UTF_8); + badBuf.writerIndex(0); + assertThat(badBuf.readerIndex()).isZero(); + + ByteBufAllocator allocator = + new AbstractByteBufAllocator() { + @Override + public boolean isDirectBufferPooled() { + return false; + } + + @Override + protected ByteBuf newHeapBuffer(int initialCapacity, int maxCapacity) { + return badBuf; + } + + @Override + protected ByteBuf newDirectBuffer(int initialCapacity, int maxCapacity) { + return badBuf; + } + }; + + assertThatCode(() -> CompositeMetadataCodec.encodeMetadataHeader(allocator, "custom/type", 0)) + .doesNotThrowAnyException(); + + assertThat(badBuf.readByte()).isEqualTo((byte) 10); + assertThat(badBuf.readCharSequence(11, CharsetUtil.UTF_8)).hasToString("custom/type"); + assertThat(badBuf.readUnsignedMedium()).isEqualTo(0); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataTest.java new file mode 100644 index 000000000..0b81ab4b0 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataTest.java @@ -0,0 +1,178 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.rsocket.metadata.CompositeMetadata.Entry; +import io.rsocket.metadata.CompositeMetadata.ReservedMimeTypeEntry; +import io.rsocket.metadata.CompositeMetadata.WellKnownMimeTypeEntry; +import io.rsocket.test.util.ByteBufUtils; +import io.rsocket.util.NumberUtils; +import java.util.Iterator; +import java.util.Spliterator; +import org.junit.jupiter.api.Test; + +class CompositeMetadataTest { + + @Test + void decodeEntryHasNoContentLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(0); + fakeEntry.writeCharSequence("w", CharsetUtil.US_ASCII); + CompositeMetadata compositeMetadata = new CompositeMetadata(fakeEntry, false); + + assertThatIllegalStateException() + .isThrownBy(() -> compositeMetadata.iterator().next()) + .withMessage("metadata is malformed"); + } + + @Test + void decodeEntryOnDoneBufferThrowsIllegalArgument() { + ByteBuf fakeBuffer = ByteBufUtils.getRandomByteBuf(0); + CompositeMetadata compositeMetadata = new CompositeMetadata(fakeBuffer, false); + + assertThatIllegalArgumentException() + .isThrownBy(() -> compositeMetadata.iterator().next()) + .withMessage("entry index 0 is larger than buffer size"); + } + + @Test + void decodeEntryTooShortForContentLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(1); + fakeEntry.writeCharSequence("w", CharsetUtil.US_ASCII); + NumberUtils.encodeUnsignedMedium(fakeEntry, 456); + fakeEntry.writeChar('w'); + CompositeMetadata compositeMetadata = new CompositeMetadata(fakeEntry, false); + + assertThatIllegalStateException() + .isThrownBy(() -> compositeMetadata.iterator().next()) + .withMessage("metadata is malformed"); + } + + @Test + void decodeEntryTooShortForMimeLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(120); + CompositeMetadata compositeMetadata = new CompositeMetadata(fakeEntry, false); + + assertThatIllegalStateException() + .isThrownBy(() -> compositeMetadata.iterator().next()) + .withMessage("metadata is malformed"); + } + + @Test + void decodeThreeEntries() { + // metadata 1: well known + WellKnownMimeType mimeType1 = WellKnownMimeType.APPLICATION_PDF; + ByteBuf metadata1 = Unpooled.buffer(); + metadata1.writeCharSequence("abcdefghijkl", CharsetUtil.UTF_8); + + // metadata 2: custom + String mimeType2 = "application/custom"; + ByteBuf metadata2 = Unpooled.buffer(); + metadata2.writeChar('E'); + metadata2.writeChar('∑'); + metadata2.writeChar('é'); + metadata2.writeBoolean(true); + metadata2.writeChar('W'); + + // metadata 3: reserved but unknown + byte reserved = 120; + assertThat(WellKnownMimeType.fromIdentifier(reserved)) + .as("ensure UNKNOWN RESERVED used in test") + .isSameAs(WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE); + ByteBuf metadata3 = Unpooled.buffer(); + metadata3.writeByte(88); + + CompositeByteBuf compositeMetadataBuffer = ByteBufAllocator.DEFAULT.compositeBuffer(); + CompositeMetadataCodec.encodeAndAddMetadata( + compositeMetadataBuffer, ByteBufAllocator.DEFAULT, mimeType1, metadata1); + CompositeMetadataCodec.encodeAndAddMetadata( + compositeMetadataBuffer, ByteBufAllocator.DEFAULT, mimeType2, metadata2); + CompositeMetadataCodec.encodeAndAddMetadata( + compositeMetadataBuffer, ByteBufAllocator.DEFAULT, reserved, metadata3); + + Iterator iterator = new CompositeMetadata(compositeMetadataBuffer, true).iterator(); + + assertThat(iterator.next()) + .as("entry1") + .isNotNull() + .satisfies( + e -> + assertThat(e.getMimeType()).as("entry1 mime type").isEqualTo(mimeType1.getString())) + .satisfies( + e -> + assertThat(((WellKnownMimeTypeEntry) e).getType()) + .as("entry1 mime id") + .isEqualTo(WellKnownMimeType.APPLICATION_PDF)) + .satisfies( + e -> + assertThat(e.getContent().toString(CharsetUtil.UTF_8)) + .as("entry1 decoded") + .isEqualTo("abcdefghijkl")); + + assertThat(iterator.next()) + .as("entry2") + .isNotNull() + .satisfies(e -> assertThat(e.getMimeType()).as("entry2 mime type").isEqualTo(mimeType2)) + .satisfies( + e -> assertThat(e.getContent()).as("entry2 decoded").isEqualByComparingTo(metadata2)); + + assertThat(iterator.next()) + .as("entry3") + .isNotNull() + .satisfies(e -> assertThat(e.getMimeType()).as("entry3 mime type").isNull()) + .satisfies( + e -> + assertThat(((ReservedMimeTypeEntry) e).getType()) + .as("entry3 mime id") + .isEqualTo(reserved)) + .satisfies( + e -> assertThat(e.getContent()).as("entry3 decoded").isEqualByComparingTo(metadata3)); + + assertThat(iterator.hasNext()).as("has no more than 3 entries").isFalse(); + } + + @Test + void streamIsNotParallel() { + final CompositeMetadata metadata = + new CompositeMetadata(ByteBufUtils.getRandomByteBuf(5), false); + + assertThat(metadata.stream().isParallel()).as("isParallel").isFalse(); + } + + @Test + void streamSpliteratorCharacteristics() { + final CompositeMetadata metadata = + new CompositeMetadata(ByteBufUtils.getRandomByteBuf(5), false); + + assertThat(metadata.stream().spliterator()) + .matches(s -> s.hasCharacteristics(Spliterator.ORDERED), "ORDERED") + .matches(s -> s.hasCharacteristics(Spliterator.DISTINCT), "DISTINCT") + .matches(s -> s.hasCharacteristics(Spliterator.NONNULL), "NONNULL") + .matches(s -> !s.hasCharacteristics(Spliterator.SIZED), "not SIZED"); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/MimeTypeMetadataCodecTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/MimeTypeMetadataCodecTest.java new file mode 100644 index 000000000..5c8d40306 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/MimeTypeMetadataCodecTest.java @@ -0,0 +1,68 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.metadata; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import java.util.List; +import org.assertj.core.util.Lists; +import org.junit.jupiter.api.Test; + +/** Unit tests for {@link MimeTypeMetadataCodec}. */ +public class MimeTypeMetadataCodecTest { + + @Test + public void wellKnownMimeType() { + WellKnownMimeType mimeType = WellKnownMimeType.APPLICATION_HESSIAN; + ByteBuf byteBuf = MimeTypeMetadataCodec.encode(ByteBufAllocator.DEFAULT, mimeType); + try { + List mimeTypes = MimeTypeMetadataCodec.decode(byteBuf); + + assertThat(mimeTypes.size()).isEqualTo(1); + assertThat(WellKnownMimeType.fromString(mimeTypes.get(0))).isEqualTo(mimeType); + } finally { + byteBuf.release(); + } + } + + @Test + public void customMimeType() { + String mimeType = "aaa/bb"; + ByteBuf byteBuf = MimeTypeMetadataCodec.encode(ByteBufAllocator.DEFAULT, mimeType); + try { + List mimeTypes = MimeTypeMetadataCodec.decode(byteBuf); + + assertThat(mimeTypes.size()).isEqualTo(1); + assertThat(mimeTypes.get(0)).isEqualTo(mimeType); + } finally { + byteBuf.release(); + } + } + + @Test + public void multipleMimeTypes() { + List mimeTypes = Lists.newArrayList("aaa/bbb", "application/x-hessian"); + ByteBuf byteBuf = MimeTypeMetadataCodec.encode(ByteBufAllocator.DEFAULT, mimeTypes); + + try { + assertThat(MimeTypeMetadataCodec.decode(byteBuf)).isEqualTo(mimeTypes); + } finally { + byteBuf.release(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/TaggingMetadataTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/TaggingMetadataTest.java new file mode 100644 index 000000000..b65ffafee --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/TaggingMetadataTest.java @@ -0,0 +1,47 @@ +package io.rsocket.metadata; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBufAllocator; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import org.junit.jupiter.api.Test; + +/** + * Tagging metadata test + * + * @author linux_china + */ +public class TaggingMetadataTest { + private ByteBufAllocator byteBufAllocator = ByteBufAllocator.DEFAULT; + + @Test + public void testParseTags() { + List tags = + Arrays.asList( + "ws://localhost:8080/rsocket", String.join("", Collections.nCopies(129, "x"))); + TaggingMetadata taggingMetadata = + TaggingMetadataCodec.createTaggingMetadata( + byteBufAllocator, "message/x.rsocket.routing.v0", tags); + TaggingMetadata taggingMetadataCopy = + new TaggingMetadata("message/x.rsocket.routing.v0", taggingMetadata.getContent()); + assertThat(tags) + .containsExactlyElementsOf(taggingMetadataCopy.stream().collect(Collectors.toList())); + } + + @Test + public void testEmptyTagAndOverLengthTag() { + List tags = + Arrays.asList( + "ws://localhost:8080/rsocket", "", String.join("", Collections.nCopies(256, "x"))); + TaggingMetadata taggingMetadata = + TaggingMetadataCodec.createTaggingMetadata( + byteBufAllocator, "message/x.rsocket.routing.v0", tags); + TaggingMetadata taggingMetadataCopy = + new TaggingMetadata("message/x.rsocket.routing.v0", taggingMetadata.getContent()); + assertThat(tags.subList(0, 1)) + .containsExactlyElementsOf(taggingMetadataCopy.stream().collect(Collectors.toList())); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/TracingMetadataCodecTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/TracingMetadataCodecTest.java new file mode 100644 index 000000000..cb8478c13 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/TracingMetadataCodecTest.java @@ -0,0 +1,209 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.ReferenceCounted; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +public class TracingMetadataCodecTest { + + private static Stream flags() { + return Stream.of(TracingMetadataCodec.Flags.values()); + } + + @ParameterizedTest + @MethodSource("flags") + public void shouldEncodeEmptyTrace(TracingMetadataCodec.Flags expectedFlag) { + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + ByteBuf byteBuf = TracingMetadataCodec.encodeEmpty(allocator, expectedFlag); + + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(byteBuf); + + Assertions.assertThat(tracingMetadata) + .matches(TracingMetadata::isEmpty) + .matches( + tm -> { + switch (expectedFlag) { + case UNDECIDED: + return !tm.isDecided(); + case NOT_SAMPLE: + return tm.isDecided() && !tm.isSampled(); + case SAMPLE: + return tm.isDecided() && tm.isSampled(); + case DEBUG: + return tm.isDecided() && tm.isDebug(); + } + return false; + }); + Assertions.assertThat(byteBuf).matches(ReferenceCounted::release); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("flags") + public void shouldEncodeTrace64WithParent(TracingMetadataCodec.Flags expectedFlag) { + long traceId = ThreadLocalRandom.current().nextLong(); + long spanId = ThreadLocalRandom.current().nextLong(); + long parentId = ThreadLocalRandom.current().nextLong(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + ByteBuf byteBuf = + TracingMetadataCodec.encode64(allocator, traceId, spanId, parentId, expectedFlag); + + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(byteBuf); + + Assertions.assertThat(tracingMetadata) + .matches(metadata -> !metadata.isEmpty()) + .matches(tm -> tm.traceIdHigh() == 0) + .matches(tm -> tm.traceId() == traceId) + .matches(tm -> tm.spanId() == spanId) + .matches(tm -> tm.hasParent()) + .matches(tm -> tm.parentId() == parentId) + .matches( + tm -> { + switch (expectedFlag) { + case UNDECIDED: + return !tm.isDecided(); + case NOT_SAMPLE: + return tm.isDecided() && !tm.isSampled(); + case SAMPLE: + return tm.isDecided() && tm.isSampled(); + case DEBUG: + return tm.isDecided() && tm.isDebug(); + } + return false; + }); + Assertions.assertThat(byteBuf).matches(ReferenceCounted::release); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("flags") + public void shouldEncodeTrace64(TracingMetadataCodec.Flags expectedFlag) { + long traceId = ThreadLocalRandom.current().nextLong(); + long spanId = ThreadLocalRandom.current().nextLong(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + ByteBuf byteBuf = TracingMetadataCodec.encode64(allocator, traceId, spanId, expectedFlag); + + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(byteBuf); + + Assertions.assertThat(tracingMetadata) + .matches(metadata -> !metadata.isEmpty()) + .matches(tm -> tm.traceIdHigh() == 0) + .matches(tm -> tm.traceId() == traceId) + .matches(tm -> tm.spanId() == spanId) + .matches(tm -> !tm.hasParent()) + .matches(tm -> tm.parentId() == 0) + .matches( + tm -> { + switch (expectedFlag) { + case UNDECIDED: + return !tm.isDecided(); + case NOT_SAMPLE: + return tm.isDecided() && !tm.isSampled(); + case SAMPLE: + return tm.isDecided() && tm.isSampled(); + case DEBUG: + return tm.isDecided() && tm.isDebug(); + } + return false; + }); + Assertions.assertThat(byteBuf).matches(ReferenceCounted::release); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("flags") + public void shouldEncodeTrace128WithParent(TracingMetadataCodec.Flags expectedFlag) { + long traceIdHighLocal; + do { + traceIdHighLocal = ThreadLocalRandom.current().nextLong(); + + } while (traceIdHighLocal == 0); + long traceIdHigh = traceIdHighLocal; + long traceId = ThreadLocalRandom.current().nextLong(); + long spanId = ThreadLocalRandom.current().nextLong(); + long parentId = ThreadLocalRandom.current().nextLong(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + ByteBuf byteBuf = + TracingMetadataCodec.encode128( + allocator, traceIdHigh, traceId, spanId, parentId, expectedFlag); + + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(byteBuf); + + Assertions.assertThat(tracingMetadata) + .matches(metadata -> !metadata.isEmpty()) + .matches(tm -> tm.traceIdHigh() == traceIdHigh) + .matches(tm -> tm.traceId() == traceId) + .matches(tm -> tm.spanId() == spanId) + .matches(tm -> tm.hasParent()) + .matches(tm -> tm.parentId() == parentId) + .matches( + tm -> { + switch (expectedFlag) { + case UNDECIDED: + return !tm.isDecided(); + case NOT_SAMPLE: + return tm.isDecided() && !tm.isSampled(); + case SAMPLE: + return tm.isDecided() && tm.isSampled(); + case DEBUG: + return tm.isDecided() && tm.isDebug(); + } + return false; + }); + Assertions.assertThat(byteBuf).matches(ReferenceCounted::release); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("flags") + public void shouldEncodeTrace128(TracingMetadataCodec.Flags expectedFlag) { + long traceIdHighLocal; + do { + traceIdHighLocal = ThreadLocalRandom.current().nextLong(); + + } while (traceIdHighLocal == 0); + long traceIdHigh = traceIdHighLocal; + long traceId = ThreadLocalRandom.current().nextLong(); + long spanId = ThreadLocalRandom.current().nextLong(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + ByteBuf byteBuf = + TracingMetadataCodec.encode128(allocator, traceIdHigh, traceId, spanId, expectedFlag); + + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(byteBuf); + + Assertions.assertThat(tracingMetadata) + .matches(metadata -> !metadata.isEmpty()) + .matches(tm -> tm.traceIdHigh() == traceIdHigh) + .matches(tm -> tm.traceId() == traceId) + .matches(tm -> tm.spanId() == spanId) + .matches(tm -> !tm.hasParent()) + .matches(tm -> tm.parentId() == 0) + .matches( + tm -> { + switch (expectedFlag) { + case UNDECIDED: + return !tm.isDecided(); + case NOT_SAMPLE: + return tm.isDecided() && !tm.isSampled(); + case SAMPLE: + return tm.isDecided() && tm.isSampled(); + case DEBUG: + return tm.isDecided() && tm.isDebug(); + } + return false; + }); + Assertions.assertThat(byteBuf).matches(ReferenceCounted::release); + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/WellKnownMimeTypeTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/WellKnownMimeTypeTest.java new file mode 100644 index 000000000..316aaf091 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/WellKnownMimeTypeTest.java @@ -0,0 +1,74 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; + +class WellKnownMimeTypeTest { + + @Test + void fromIdentifierGreaterThan127() { + assertThat(WellKnownMimeType.fromIdentifier(128)) + .isSameAs(WellKnownMimeType.UNPARSEABLE_MIME_TYPE); + } + + @Test + void fromIdentifierMatchFromMimeType() { + for (WellKnownMimeType mimeType : WellKnownMimeType.values()) { + if (mimeType == WellKnownMimeType.UNPARSEABLE_MIME_TYPE + || mimeType == WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE) { + continue; + } + assertThat(WellKnownMimeType.fromString(mimeType.toString())) + .as("mimeType string for " + mimeType.name()) + .isSameAs(mimeType); + + assertThat(WellKnownMimeType.fromIdentifier(mimeType.getIdentifier())) + .as("mimeType ID for " + mimeType.name()) + .isSameAs(mimeType); + } + } + + @Test + void fromIdentifierNegative() { + assertThat(WellKnownMimeType.fromIdentifier(-1)) + .isSameAs(WellKnownMimeType.fromIdentifier(-2)) + .isSameAs(WellKnownMimeType.fromIdentifier(-12)) + .isSameAs(WellKnownMimeType.UNPARSEABLE_MIME_TYPE); + } + + @Test + void fromIdentifierReserved() { + assertThat(WellKnownMimeType.fromIdentifier(120)) + .isSameAs(WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE); + } + + @Test + void fromStringUnknown() { + assertThat(WellKnownMimeType.fromString("foo/bar")) + .isSameAs(WellKnownMimeType.UNPARSEABLE_MIME_TYPE); + } + + @Test + void fromStringUnknownReservedStillReturnsUnparseable() { + assertThat( + WellKnownMimeType.fromString(WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE.getString())) + .isSameAs(WellKnownMimeType.UNPARSEABLE_MIME_TYPE); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java b/rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java new file mode 100644 index 000000000..9a19050f9 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java @@ -0,0 +1,790 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.FrameType; +import io.rsocket.transport.local.LocalClientTransport; +import io.rsocket.transport.local.LocalServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.util.annotation.Nullable; + +public class RequestInterceptorTest { + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheClientRequesterSide(boolean errorOutcome) { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .bindNow(LocalServerTransport.create("test")); + + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final RSocket rSocket = + RSocketConnector.create() + .interceptors( + ir -> + ir.forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor)) + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel(Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheClientResponderSide(boolean errorOutcome) + throws InterruptedException { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + + CountDownLatch latch = new CountDownLatch(1); + final Closeable closeable = + RSocketServer.create( + (setup, rSocket) -> + Mono.just(new RSocket() {}) + .doAfterTerminate( + () -> { + new Thread( + () -> { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel( + Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + latch.countDown(); + }) + .start(); + })) + .bindNow(LocalServerTransport.create("test")); + + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final RSocket rSocket = + RSocketConnector.create() + .acceptor( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .interceptors( + ir -> + ir.forRequestsInResponder( + (Function) + (__) -> testRequestInterceptor)) + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + Assertions.assertThat(latch.await(1, TimeUnit.SECONDS)).isTrue(); + + testRequestInterceptor + .expectOnStart(2, FrameType.REQUEST_FNF) + .expectOnComplete(2) + .expectOnStart(4, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 4) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(6, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 6) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(8, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 8) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheServerRequesterSide(boolean errorOutcome) { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .interceptors( + ir -> + ir.forRequestsInResponder( + (Function) + (__) -> testRequestInterceptor)) + .bindNow(LocalServerTransport.create("test")); + final RSocket rSocket = + RSocketConnector.create() + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel(Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheServerResponderSide(boolean errorOutcome) + throws InterruptedException { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + + CountDownLatch latch = new CountDownLatch(1); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final Closeable closeable = + RSocketServer.create( + (setup, rSocket) -> + Mono.just(new RSocket() {}) + .doAfterTerminate( + () -> { + new Thread( + () -> { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel( + Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + latch.countDown(); + }) + .start(); + })) + .interceptors( + ir -> + ir.forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor)) + .bindNow(LocalServerTransport.create("test")); + final RSocket rSocket = + RSocketConnector.create() + .acceptor( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + Assertions.assertThat(latch.await(1, TimeUnit.SECONDS)).isTrue(); + + testRequestInterceptor + .expectOnStart(2, FrameType.REQUEST_FNF) + .expectOnComplete(2) + .expectOnStart(4, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 4) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(6, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 6) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(8, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 8) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } + + @Test + void ensuresExceptionInTheInterceptorIsHandledProperly() { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads); + } + })) + .bindNow(LocalServerTransport.create("test")); + + final RequestInterceptor testRequestInterceptor = + new RequestInterceptor() { + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnStart"); + } + + @Override + public void onTerminate( + int streamId, FrameType requestType, @Nullable Throwable terminalSignal) { + throw new RuntimeException("testOnTerminate"); + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + throw new RuntimeException("testOnCancel"); + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnReject"); + } + + @Override + public void dispose() {} + }; + final RSocket rSocket = + RSocketConnector.create() + .interceptors( + ir -> + ir.forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor)) + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + StepVerifier.create(rSocket.fireAndForget(DefaultPayload.create("test"))) + .expectSubscription() + .expectComplete() + .verify(); + + StepVerifier.create(rSocket.requestResponse(DefaultPayload.create("test"))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(); + + StepVerifier.create(rSocket.requestStream(DefaultPayload.create("test"))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(); + + StepVerifier.create(rSocket.requestChannel(Flux.just(DefaultPayload.create("test")))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(); + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void shouldSupportMultipleInterceptors(boolean errorOutcome) { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .bindNow(LocalServerTransport.create("test")); + + final RequestInterceptor testRequestInterceptor1 = + new RequestInterceptor() { + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnStart"); + } + + @Override + public void onTerminate( + int streamId, FrameType requestType, @Nullable Throwable terminalSignal) { + throw new RuntimeException("testOnTerminate"); + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + throw new RuntimeException("testOnTerminate"); + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnReject"); + } + + @Override + public void dispose() {} + }; + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequestInterceptor testRequestInterceptor2 = new TestRequestInterceptor(); + final RSocket rSocket = + RSocketConnector.create() + .interceptors( + ir -> + ir.forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor) + .forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor1) + .forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor2)) + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel(Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + + testRequestInterceptor2 + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java b/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java new file mode 100644 index 000000000..8261b3f49 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java @@ -0,0 +1,142 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.jctools.queues.MpscUnboundedArrayQueue; +import java.util.Queue; +import java.util.function.Consumer; +import org.assertj.core.api.Assertions; +import org.assertj.core.api.Condition; +import reactor.util.annotation.Nullable; + +public class TestRequestInterceptor implements RequestInterceptor { + + final Queue events = new MpscUnboundedArrayQueue<>(128); + + @Override + public void dispose() {} + + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + events.add(new Event(EventType.ON_START, streamId, requestType, null)); + } + + @Override + public void onTerminate(int streamId, FrameType requestType, @Nullable Throwable t) { + events.add( + new Event( + t == null ? EventType.ON_COMPLETE : EventType.ON_ERROR, streamId, requestType, t)); + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + events.add(new Event(EventType.ON_CANCEL, streamId, requestType, null)); + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + events.add(new Event(EventType.ON_REJECT, -1, requestType, rejectionReason)); + } + + public TestRequestInterceptor expectOnStart(int streamId, FrameType requestType) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_START) + .hasFieldOrPropertyWithValue("streamId", streamId) + .hasFieldOrPropertyWithValue("requestType", requestType); + + return this; + } + + public TestRequestInterceptor expectOnComplete(int streamId) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_COMPLETE) + .hasFieldOrPropertyWithValue("streamId", streamId); + + return this; + } + + public TestRequestInterceptor expectOnError(int streamId) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_ERROR) + .hasFieldOrPropertyWithValue("streamId", streamId); + + return this; + } + + public TestRequestInterceptor expectOnCancel(int streamId) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_CANCEL) + .hasFieldOrPropertyWithValue("streamId", streamId); + + return this; + } + + public TestRequestInterceptor assertNext(Consumer consumer) { + final Event event = events.poll(); + Assertions.assertThat(event).isNotNull(); + + consumer.accept(event); + + return this; + } + + public TestRequestInterceptor expectOnReject(FrameType requestType, Throwable rejectionReason) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_REJECT) + .has( + new Condition<>( + e -> { + Assertions.assertThat(e.error) + .isExactlyInstanceOf(rejectionReason.getClass()) + .hasMessage(rejectionReason.getMessage()) + .hasCause(rejectionReason.getCause()); + return true; + }, + "Has rejection reason which matches to %s", + rejectionReason)) + .hasFieldOrPropertyWithValue("requestType", requestType); + + return this; + } + + public TestRequestInterceptor expectNothing() { + final Event event = events.poll(); + + Assertions.assertThat(event).isNull(); + + return this; + } + + public static final class Event { + public final EventType eventType; + public final int streamId; + public final FrameType requestType; + public final Throwable error; + + Event(EventType eventType, int streamId, FrameType requestType, Throwable error) { + this.eventType = eventType; + this.streamId = streamId; + this.requestType = requestType; + this.error = error; + } + } + + public enum EventType { + ON_START, + ON_COMPLETE, + ON_ERROR, + ON_CANCEL, + ON_REJECT + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/resume/ClientRSocketSessionTest.java b/rsocket-core/src/test/java/io/rsocket/resume/ClientRSocketSessionTest.java new file mode 100644 index 000000000..8229bf42b --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/resume/ClientRSocketSessionTest.java @@ -0,0 +1,470 @@ +package io.rsocket.resume; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.ReferenceCounted; +import io.rsocket.FrameAssert; +import io.rsocket.exceptions.ConnectionCloseException; +import io.rsocket.exceptions.RejectedResumeException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.ResumeOkFrameCodec; +import io.rsocket.keepalive.KeepAliveSupport; +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.test.util.TestDuplexConnection; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Operators; +import reactor.test.StepVerifier; +import reactor.test.scheduler.VirtualTimeScheduler; +import reactor.util.function.Tuples; +import reactor.util.retry.Retry; + +public class ClientRSocketSessionTest { + + @Test + void sessionTimeoutSmokeTest() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ClientRSocketSession session = + new ClientRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.connect().delaySubscription(Duration.ofMillis(1)), + c -> { + AtomicBoolean firstHandled = new AtomicBoolean(); + return ((TestDuplexConnection) c) + .receive() + .next() + .doOnNext(__ -> firstHandled.set(true)) + .doOnCancel( + () -> { + if (firstHandled.compareAndSet(false, true)) { + c.dispose(); + } + }) + .map(b -> Tuples.of(b, c)); + }, + framesStore, + Duration.ofMinutes(1), + Retry.indefinitely(), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + // deactivate connection + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time so new connection is received + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(1)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(50)); + // timeout should not terminate current connection + assertThat(transport.testConnection().isDisposed()).isFalse(); + + // send RESUME_OK frame + transport + .testConnection() + .addToReceivedBuffer(ResumeOkFrameCodec.encode(transport.alloc(), 0)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be terminated + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(15)); + + // disconnects for the second time + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time so new connection is received + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(1)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + transport + .testConnection() + .addToReceivedBuffer( + ErrorFrameCodec.encode( + transport.alloc(), 0, new ConnectionCloseException("some message"))); + // connection should be closed because of the wrong first frame + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout is still in progress + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(30)); + // should obtain new connection + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_OK frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(30)); + + assertThat(session.s).isEqualTo(Operators.cancelledSubscription()); + assertThat(transport.testConnection().isDisposed()).isTrue(); + + assertThat(session.isDisposed()).isTrue(); + + resumableDuplexConnection.onClose().as(StepVerifier::create).expectComplete().verify(); + keepAliveSupport.dispose(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } + + @Test + void sessionTerminationOnWrongFrameTest() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ClientRSocketSession session = + new ClientRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.connect().delaySubscription(Duration.ofMillis(1)), + c -> { + AtomicBoolean firstHandled = new AtomicBoolean(); + return ((TestDuplexConnection) c) + .receive() + .next() + .doOnNext(__ -> firstHandled.set(true)) + .doOnCancel( + () -> { + if (firstHandled.compareAndSet(false, true)) { + c.dispose(); + } + }) + .map(b -> Tuples.of(b, c)); + }, + framesStore, + Duration.ofMinutes(1), + Retry.indefinitely(), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + // deactivate connection + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time so new connection is received + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(1)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(50)); + // timeout should not terminate current connection + assertThat(transport.testConnection().isDisposed()).isFalse(); + + // send RESUME_OK frame + transport + .testConnection() + .addToReceivedBuffer(ResumeOkFrameCodec.encode(transport.alloc(), 0)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be terminated + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(15)); + + // disconnects for the second time + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time so new connection is received + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(1)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + // Send KEEPALIVE frame as a first frame + transport + .testConnection() + .addToReceivedBuffer( + KeepAliveFrameCodec.encode(transport.alloc(), false, 0, Unpooled.EMPTY_BUFFER)); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(30)); + + assertThat(session.s).isEqualTo(Operators.cancelledSubscription()); + assertThat(transport.testConnection().isDisposed()).isTrue(); + assertThat(session.isDisposed()).isTrue(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.ERROR) + .matches(ReferenceCounted::release); + + resumableDuplexConnection + .onClose() + .as(StepVerifier::create) + .expectErrorMessage("RESUME_OK frame must be received before any others") + .verify(); + keepAliveSupport.dispose(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } + + @Test + void shouldErrorWithNoRetriesOnErrorFrameTest() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ClientRSocketSession session = + new ClientRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.connect().delaySubscription(Duration.ofMillis(1)), + c -> { + AtomicBoolean firstHandled = new AtomicBoolean(); + return ((TestDuplexConnection) c) + .receive() + .next() + .doOnNext(__ -> firstHandled.set(true)) + .doOnCancel( + () -> { + if (firstHandled.compareAndSet(false, true)) { + c.dispose(); + } + }) + .map(b -> Tuples.of(b, c)); + }, + framesStore, + Duration.ofMinutes(1), + Retry.indefinitely(), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + // deactivate connection + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time so new connection is received + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(1)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(50)); + // timeout should not terminate current connection + assertThat(transport.testConnection().isDisposed()).isFalse(); + + // send REJECTED_RESUME_ERROR frame + transport + .testConnection() + .addToReceivedBuffer( + ErrorFrameCodec.encode( + transport.alloc(), 0, new RejectedResumeException("failed resumption"))); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // timeout should be terminated + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isTrue(); + + resumableDuplexConnection + .onClose() + .as(StepVerifier::create) + .expectError(RejectedResumeException.class) + .verify(); + keepAliveSupport.dispose(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } + + @Test + void shouldTerminateConnectionOnIllegalStateInKeepAliveFrame() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ClientRSocketSession session = + new ClientRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.connect().delaySubscription(Duration.ofMillis(1)), + c -> { + AtomicBoolean firstHandled = new AtomicBoolean(); + return ((TestDuplexConnection) c) + .receive() + .next() + .doOnNext(__ -> firstHandled.set(true)) + .doOnCancel( + () -> { + if (firstHandled.compareAndSet(false, true)) { + c.dispose(); + } + }) + .map(b -> Tuples.of(b, c)); + }, + framesStore, + Duration.ofMinutes(1), + Retry.indefinitely(), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + keepAliveSupport.resumeState(session); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + final ByteBuf keepAliveFrame = + KeepAliveFrameCodec.encode(transport.alloc(), false, 1529, Unpooled.EMPTY_BUFFER); + keepAliveSupport.receive(keepAliveFrame); + keepAliveFrame.release(); + + assertThat(transport.testConnection().isDisposed()).isTrue(); + // timeout should be terminated + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isTrue(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.ERROR) + .matches(ReferenceCounted::release); + + resumableDuplexConnection.onClose().as(StepVerifier::create).expectError().verify(); + keepAliveSupport.dispose(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/resume/InMemoryResumeStoreTest.java b/rsocket-core/src/test/java/io/rsocket/resume/InMemoryResumeStoreTest.java new file mode 100644 index 000000000..bba40d674 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/resume/InMemoryResumeStoreTest.java @@ -0,0 +1,547 @@ +package io.rsocket.resume; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.ReferenceCounted; +import io.rsocket.RaceTestConstants; +import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.internal.subscriber.AssertSubscriber; +import java.util.Arrays; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.Disposable; +import reactor.core.publisher.Hooks; +import reactor.test.util.RaceTestUtils; + +public class InMemoryResumeStoreTest { + + @Test + void saveNonResumableFrame() { + final InMemoryResumableFramesStore store = inMemoryStore(25); + final UnboundedProcessor sender = new UnboundedProcessor(); + + store.saveFrames(sender).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + + final ByteBuf frame1 = fakeConnectionFrame(10); + final ByteBuf frame2 = fakeConnectionFrame(35); + + sender.onNext(frame1); + sender.onNext(frame2); + + assertThat(store.cachedFrames.size()).isZero(); + assertThat(store.cacheSize).isZero(); + assertThat(store.firstAvailableFramePosition).isZero(); + + assertSubscriber.assertValueCount(2).values().forEach(ByteBuf::release); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + } + + @Test + void saveWithoutTailRemoval() { + final InMemoryResumableFramesStore store = inMemoryStore(25); + final UnboundedProcessor sender = new UnboundedProcessor(); + + store.saveFrames(sender).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + + final ByteBuf frame = fakeResumableFrame(10); + + sender.onNext(frame); + + assertThat(store.cachedFrames.size()).isEqualTo(1); + assertThat(store.cacheSize).isEqualTo(frame.readableBytes()); + assertThat(store.firstAvailableFramePosition).isZero(); + + assertSubscriber.assertValueCount(1).values().forEach(ByteBuf::release); + + assertThat(frame.refCnt()).isOne(); + } + + @Test + void saveRemoveOneFromTail() { + final InMemoryResumableFramesStore store = inMemoryStore(25); + final UnboundedProcessor sender = new UnboundedProcessor(); + + store.saveFrames(sender).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + final ByteBuf frame1 = fakeResumableFrame(20); + final ByteBuf frame2 = fakeResumableFrame(10); + + sender.onNext(frame1); + sender.onNext(frame2); + + assertThat(store.cachedFrames.size()).isOne(); + assertThat(store.cacheSize).isEqualTo(frame2.readableBytes()); + assertThat(store.firstAvailableFramePosition).isEqualTo(frame1.readableBytes()); + + assertSubscriber.assertValueCount(2).values().forEach(ByteBuf::release); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isOne(); + } + + @Test + void saveRemoveTwoFromTail() { + final InMemoryResumableFramesStore store = inMemoryStore(25); + final UnboundedProcessor sender = new UnboundedProcessor(); + + store.saveFrames(sender).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + + final ByteBuf frame1 = fakeResumableFrame(10); + final ByteBuf frame2 = fakeResumableFrame(10); + final ByteBuf frame3 = fakeResumableFrame(20); + + sender.onNext(frame1); + sender.onNext(frame2); + sender.onNext(frame3); + + assertThat(store.cachedFrames.size()).isOne(); + assertThat(store.cacheSize).isEqualTo(frame3.readableBytes()); + assertThat(store.firstAvailableFramePosition).isEqualTo(size(frame1, frame2)); + + assertSubscriber.assertValueCount(3).values().forEach(ByteBuf::release); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + assertThat(frame3.refCnt()).isOne(); + } + + @Test + void saveBiggerThanStore() { + final InMemoryResumableFramesStore store = inMemoryStore(25); + final UnboundedProcessor sender = new UnboundedProcessor(); + + store.saveFrames(sender).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + final ByteBuf frame1 = fakeResumableFrame(10); + final ByteBuf frame2 = fakeResumableFrame(10); + final ByteBuf frame3 = fakeResumableFrame(30); + + sender.onNext(frame1); + sender.onNext(frame2); + sender.onNext(frame3); + + assertThat(store.cachedFrames.size()).isZero(); + assertThat(store.cacheSize).isZero(); + assertThat(store.firstAvailableFramePosition).isEqualTo(size(frame1, frame2, frame3)); + + assertSubscriber.assertValueCount(3).values().forEach(ByteBuf::release); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + assertThat(frame3.refCnt()).isZero(); + } + + @Test + void releaseFrames() { + final InMemoryResumableFramesStore store = inMemoryStore(100); + + final UnboundedProcessor producer = new UnboundedProcessor(); + store.saveFrames(producer).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + + final ByteBuf frame1 = fakeResumableFrame(10); + final ByteBuf frame2 = fakeResumableFrame(10); + final ByteBuf frame3 = fakeResumableFrame(30); + + producer.onNext(frame1); + producer.onNext(frame2); + producer.onNext(frame3); + + store.releaseFrames(20); + + assertThat(store.cachedFrames.size()).isOne(); + assertThat(store.cacheSize).isEqualTo(frame3.readableBytes()); + assertThat(store.firstAvailableFramePosition).isEqualTo(size(frame1, frame2)); + + assertSubscriber.assertValueCount(3).values().forEach(ByteBuf::release); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + assertThat(frame3.refCnt()).isOne(); + } + + @Test + void receiveImpliedPosition() { + final InMemoryResumableFramesStore store = inMemoryStore(100); + + ByteBuf frame1 = fakeResumableFrame(10); + ByteBuf frame2 = fakeResumableFrame(30); + + store.resumableFrameReceived(frame1); + store.resumableFrameReceived(frame2); + + assertThat(store.frameImpliedPosition()).isEqualTo(size(frame1, frame2)); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void ensuresCleansOnTerminal(boolean hasSubscriber) { + final InMemoryResumableFramesStore store = inMemoryStore(100); + + final UnboundedProcessor producer = new UnboundedProcessor(); + store.saveFrames(producer).subscribe(); + + final AssertSubscriber assertSubscriber = + hasSubscriber ? store.resumeStream().subscribeWith(AssertSubscriber.create()) : null; + + final ByteBuf frame1 = fakeResumableFrame(10); + final ByteBuf frame2 = fakeResumableFrame(10); + final ByteBuf frame3 = fakeResumableFrame(30); + + producer.onNext(frame1); + producer.onNext(frame2); + producer.onNext(frame3); + producer.onComplete(); + + assertThat(store.cachedFrames.size()).isZero(); + assertThat(store.cacheSize).isZero(); + + assertThat(producer.isDisposed()).isTrue(); + + if (hasSubscriber) { + assertSubscriber.assertValueCount(3).assertTerminated().values().forEach(ByteBuf::release); + } + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + assertThat(frame3.refCnt()).isZero(); + } + + @Test + void ensuresCleansOnTerminalLateSubscriber() { + final InMemoryResumableFramesStore store = inMemoryStore(100); + + final UnboundedProcessor producer = new UnboundedProcessor(); + store.saveFrames(producer).subscribe(); + + final ByteBuf frame1 = fakeResumableFrame(10); + final ByteBuf frame2 = fakeResumableFrame(10); + final ByteBuf frame3 = fakeResumableFrame(30); + + producer.onNext(frame1); + producer.onNext(frame2); + producer.onNext(frame3); + producer.onComplete(); + + assertThat(store.cachedFrames.size()).isZero(); + assertThat(store.cacheSize).isZero(); + + assertThat(producer.isDisposed()).isTrue(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + assertSubscriber.assertTerminated(); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + assertThat(frame3.refCnt()).isZero(); + } + + @ParameterizedTest(name = "Sending vs Reconnect Race Test. WithLateSubscriber[{0}]") + @ValueSource(booleans = {true, false}) + void sendingVsReconnectRaceTest(boolean withLateSubscriber) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final InMemoryResumableFramesStore store = inMemoryStore(Integer.MAX_VALUE); + final UnboundedProcessor frames = new UnboundedProcessor(); + final BlockingQueue receivedFrames = new ArrayBlockingQueue<>(10); + final AtomicInteger receivedPosition = new AtomicInteger(); + + store.saveFrames(frames).subscribe(); + + final Consumer consumer = + f -> { + if (ResumableDuplexConnection.isResumableFrame(f)) { + receivedPosition.addAndGet(f.readableBytes()); + receivedFrames.offer(f); + return; + } + f.release(); + }; + final AtomicReference disposableReference = + new AtomicReference<>( + withLateSubscriber ? null : store.resumeStream().subscribe(consumer)); + + final ByteBuf byteBuf1 = fakeResumableFrame(5); + final ByteBuf byteBuf11 = fakeConnectionFrame(5); + final ByteBuf byteBuf2 = fakeResumableFrame(6); + final ByteBuf byteBuf21 = fakeConnectionFrame(5); + final ByteBuf byteBuf3 = fakeResumableFrame(7); + final ByteBuf byteBuf31 = fakeConnectionFrame(5); + final ByteBuf byteBuf4 = fakeResumableFrame(8); + final ByteBuf byteBuf41 = fakeConnectionFrame(5); + final ByteBuf byteBuf5 = fakeResumableFrame(25); + final ByteBuf byteBuf51 = fakeConnectionFrame(35); + + RaceTestUtils.race( + () -> { + if (withLateSubscriber) { + disposableReference.set(store.resumeStream().subscribe(consumer)); + } + + // disconnect + disposableReference.get().dispose(); + + while (InMemoryResumableFramesStore.isWorkInProgress(store.state)) { + // ignore + } + + // mimic RESUME_OK frame received + store.releaseFrames(receivedPosition.get()); + disposableReference.set(store.resumeStream().subscribe(consumer)); + + // disconnect + disposableReference.get().dispose(); + + while (InMemoryResumableFramesStore.isWorkInProgress(store.state)) { + // ignore + } + + // mimic RESUME_OK frame received + store.releaseFrames(receivedPosition.get()); + disposableReference.set(store.resumeStream().subscribe(consumer)); + }, + () -> { + frames.onNext(byteBuf1); + frames.onNextPrioritized(byteBuf11); + frames.onNext(byteBuf2); + frames.onNext(byteBuf3); + frames.onNextPrioritized(byteBuf31); + frames.onNext(byteBuf4); + frames.onNext(byteBuf5); + }, + () -> { + frames.onNextPrioritized(byteBuf21); + frames.onNextPrioritized(byteBuf41); + frames.onNextPrioritized(byteBuf51); + }); + + store.releaseFrames(receivedFrames.stream().mapToInt(ByteBuf::readableBytes).sum()); + + assertThat(store.cacheSize).isZero(); + assertThat(store.cachedFrames).isEmpty(); + + assertThat(receivedFrames) + .hasSize(5) + .containsSequence(byteBuf1, byteBuf2, byteBuf3, byteBuf4, byteBuf5); + receivedFrames.forEach(ReferenceCounted::release); + + assertThat(byteBuf1.refCnt()).isZero(); + assertThat(byteBuf11.refCnt()).isZero(); + assertThat(byteBuf2.refCnt()).isZero(); + assertThat(byteBuf21.refCnt()).isZero(); + assertThat(byteBuf3.refCnt()).isZero(); + assertThat(byteBuf31.refCnt()).isZero(); + assertThat(byteBuf4.refCnt()).isZero(); + assertThat(byteBuf41.refCnt()).isZero(); + assertThat(byteBuf5.refCnt()).isZero(); + assertThat(byteBuf51.refCnt()).isZero(); + } + } + + @ParameterizedTest( + name = "Sending vs Reconnect with incorrect position Race Test. WithLateSubscriber[{0}]") + @ValueSource(booleans = {true, false}) + void incorrectReleaseFramesWithOnNextRaceTest(boolean withLateSubscriber) { + Hooks.onErrorDropped(t -> {}); + try { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final InMemoryResumableFramesStore store = inMemoryStore(Integer.MAX_VALUE); + final UnboundedProcessor frames = new UnboundedProcessor(); + + store.saveFrames(frames).subscribe(); + + final AtomicInteger terminationCnt = new AtomicInteger(); + final Consumer consumer = ReferenceCounted::release; + final Consumer errorConsumer = __ -> terminationCnt.incrementAndGet(); + final AtomicReference disposableReference = + new AtomicReference<>( + withLateSubscriber + ? null + : store.resumeStream().subscribe(consumer, errorConsumer)); + + final ByteBuf byteBuf1 = fakeResumableFrame(5); + final ByteBuf byteBuf11 = fakeConnectionFrame(5); + final ByteBuf byteBuf2 = fakeResumableFrame(6); + final ByteBuf byteBuf21 = fakeConnectionFrame(5); + final ByteBuf byteBuf3 = fakeResumableFrame(7); + final ByteBuf byteBuf31 = fakeConnectionFrame(5); + final ByteBuf byteBuf4 = fakeResumableFrame(8); + final ByteBuf byteBuf41 = fakeConnectionFrame(5); + final ByteBuf byteBuf5 = fakeResumableFrame(25); + final ByteBuf byteBuf51 = fakeConnectionFrame(35); + + RaceTestUtils.race( + () -> { + if (withLateSubscriber) { + disposableReference.set(store.resumeStream().subscribe(consumer, errorConsumer)); + } + // disconnect + disposableReference.get().dispose(); + + // mimic RESUME_OK frame received but with incorrect position + store.releaseFrames(25); + disposableReference.set(store.resumeStream().subscribe(consumer, errorConsumer)); + }, + () -> { + frames.onNext(byteBuf1); + frames.onNextPrioritized(byteBuf11); + frames.onNext(byteBuf2); + frames.onNext(byteBuf3); + frames.onNextPrioritized(byteBuf31); + frames.onNext(byteBuf4); + frames.onNext(byteBuf5); + }, + () -> { + frames.onNextPrioritized(byteBuf21); + frames.onNextPrioritized(byteBuf41); + frames.onNextPrioritized(byteBuf51); + }); + + assertThat(store.cacheSize).isZero(); + assertThat(store.cachedFrames).isEmpty(); + assertThat(disposableReference.get().isDisposed()).isTrue(); + assertThat(terminationCnt).hasValue(1); + + assertThat(byteBuf1.refCnt()).isZero(); + assertThat(byteBuf11.refCnt()).isZero(); + assertThat(byteBuf2.refCnt()).isZero(); + assertThat(byteBuf21.refCnt()).isZero(); + assertThat(byteBuf3.refCnt()).isZero(); + assertThat(byteBuf31.refCnt()).isZero(); + assertThat(byteBuf4.refCnt()).isZero(); + assertThat(byteBuf41.refCnt()).isZero(); + assertThat(byteBuf5.refCnt()).isZero(); + assertThat(byteBuf51.refCnt()).isZero(); + } + } finally { + Hooks.resetOnErrorDropped(); + } + } + + @ParameterizedTest( + name = + "Dispose vs Sending vs Reconnect with incorrect position Race Test. WithLateSubscriber[{0}]") + @ValueSource(booleans = {true, false}) + void incorrectReleaseFramesWithOnNextWithDisposeRaceTest(boolean withLateSubscriber) { + Hooks.onErrorDropped(t -> {}); + try { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final InMemoryResumableFramesStore store = inMemoryStore(Integer.MAX_VALUE); + final UnboundedProcessor frames = new UnboundedProcessor(); + + store.saveFrames(frames).subscribe(); + + final AtomicInteger terminationCnt = new AtomicInteger(); + final Consumer consumer = ReferenceCounted::release; + final Consumer errorConsumer = __ -> terminationCnt.incrementAndGet(); + final AtomicReference disposableReference = + new AtomicReference<>( + withLateSubscriber + ? null + : store.resumeStream().subscribe(consumer, errorConsumer)); + + final ByteBuf byteBuf1 = fakeResumableFrame(5); + final ByteBuf byteBuf11 = fakeConnectionFrame(5); + final ByteBuf byteBuf2 = fakeResumableFrame(6); + final ByteBuf byteBuf21 = fakeConnectionFrame(5); + final ByteBuf byteBuf3 = fakeResumableFrame(7); + final ByteBuf byteBuf31 = fakeConnectionFrame(5); + final ByteBuf byteBuf4 = fakeResumableFrame(8); + final ByteBuf byteBuf41 = fakeConnectionFrame(5); + final ByteBuf byteBuf5 = fakeResumableFrame(25); + final ByteBuf byteBuf51 = fakeConnectionFrame(35); + + RaceTestUtils.race( + () -> { + if (withLateSubscriber) { + disposableReference.set(store.resumeStream().subscribe(consumer, errorConsumer)); + } + // disconnect + disposableReference.get().dispose(); + + // mimic RESUME_OK frame received but with incorrect position + store.releaseFrames(25); + disposableReference.set(store.resumeStream().subscribe(consumer, errorConsumer)); + }, + () -> { + frames.onNext(byteBuf1); + frames.onNextPrioritized(byteBuf11); + frames.onNext(byteBuf2); + frames.onNext(byteBuf3); + frames.onNextPrioritized(byteBuf31); + frames.onNext(byteBuf4); + frames.onNext(byteBuf5); + }, + () -> { + frames.onNextPrioritized(byteBuf21); + frames.onNextPrioritized(byteBuf41); + frames.onNextPrioritized(byteBuf51); + }, + store::dispose); + + assertThat(store.cacheSize).isZero(); + assertThat(store.cachedFrames).isEmpty(); + assertThat(disposableReference.get().isDisposed()).isTrue(); + assertThat(terminationCnt).hasValueGreaterThanOrEqualTo(1).hasValueLessThanOrEqualTo(2); + + assertThat(byteBuf1.refCnt()).isZero(); + assertThat(byteBuf11.refCnt()).isZero(); + assertThat(byteBuf2.refCnt()).isZero(); + assertThat(byteBuf21.refCnt()).isZero(); + assertThat(byteBuf3.refCnt()).isZero(); + assertThat(byteBuf31.refCnt()).isZero(); + assertThat(byteBuf4.refCnt()).isZero(); + assertThat(byteBuf41.refCnt()).isZero(); + assertThat(byteBuf5.refCnt()).isZero(); + assertThat(byteBuf51.refCnt()).isZero(); + } + } finally { + Hooks.resetOnErrorDropped(); + } + } + + private int size(ByteBuf... byteBufs) { + return Arrays.stream(byteBufs).mapToInt(ByteBuf::readableBytes).sum(); + } + + private static InMemoryResumableFramesStore inMemoryStore(int size) { + return new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, size); + } + + private static ByteBuf fakeResumableFrame(int size) { + byte[] bytes = new byte[size]; + Arrays.fill(bytes, (byte) 7); + return Unpooled.wrappedBuffer(bytes); + } + + private static ByteBuf fakeConnectionFrame(int size) { + byte[] bytes = new byte[size]; + Arrays.fill(bytes, (byte) 0); + return Unpooled.wrappedBuffer(bytes); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/resume/ResumeCacheTest.java b/rsocket-core/src/test/java/io/rsocket/resume/ResumeCacheTest.java deleted file mode 100644 index 33bbde9b8..000000000 --- a/rsocket-core/src/test/java/io/rsocket/resume/ResumeCacheTest.java +++ /dev/null @@ -1,122 +0,0 @@ -package io.rsocket.resume; - -import static org.junit.Assert.assertEquals; - -import io.rsocket.Frame; -import io.rsocket.FrameType; -import io.rsocket.util.PayloadImpl; -import org.junit.Test; -import reactor.core.publisher.Flux; - -public class ResumeCacheTest { - private Frame CANCEL = Frame.Cancel.from(1); - private Frame STREAM = - Frame.Request.from(1, FrameType.REQUEST_STREAM, new PayloadImpl("Test"), 100); - - private ResumeCache cache = new ResumeCache(ResumePositionCounter.frames(), 2); - - @Test - public void startsEmpty() { - Flux x = cache.resend(0); - assertEquals(0L, (long) x.count().block()); - cache.updateRemotePosition(0); - } - - @Test(expected = IllegalStateException.class) - public void failsForFutureUpdatePosition() { - cache.updateRemotePosition(1); - } - - @Test(expected = IllegalStateException.class) - public void failsForFutureResend() { - cache.resend(1); - } - - @Test - public void updatesPositions() { - assertEquals(0, cache.getRemotePosition()); - assertEquals(0, cache.getCurrentPosition()); - assertEquals(0, cache.getEarliestResendPosition()); - assertEquals(0, cache.size()); - - cache.sent(STREAM); - - assertEquals(0, cache.getRemotePosition()); - assertEquals(14, cache.getCurrentPosition()); - assertEquals(0, cache.getEarliestResendPosition()); - assertEquals(1, cache.size()); - - cache.updateRemotePosition(14); - - assertEquals(14, cache.getRemotePosition()); - assertEquals(14, cache.getCurrentPosition()); - assertEquals(14, cache.getEarliestResendPosition()); - assertEquals(0, cache.size()); - - cache.sent(CANCEL); - - assertEquals(14, cache.getRemotePosition()); - assertEquals(20, cache.getCurrentPosition()); - assertEquals(14, cache.getEarliestResendPosition()); - assertEquals(1, cache.size()); - - cache.updateRemotePosition(20); - - assertEquals(20, cache.getRemotePosition()); - assertEquals(20, cache.getCurrentPosition()); - assertEquals(20, cache.getEarliestResendPosition()); - assertEquals(0, cache.size()); - - cache.sent(STREAM); - - assertEquals(20, cache.getRemotePosition()); - assertEquals(34, cache.getCurrentPosition()); - assertEquals(20, cache.getEarliestResendPosition()); - assertEquals(1, cache.size()); - } - - @Test - public void supportsZeroBuffer() { - cache = new ResumeCache(ResumePositionCounter.frames(), 0); - - cache.sent(STREAM); - cache.sent(STREAM); - cache.sent(STREAM); - - assertEquals(0, cache.getRemotePosition()); - assertEquals(42, cache.getCurrentPosition()); - assertEquals(42, cache.getEarliestResendPosition()); - assertEquals(0, cache.size()); - } - - @Test - public void supportsFrameCountBuffers() { - cache = new ResumeCache(ResumePositionCounter.size(), 100); - - assertEquals(0, cache.getRemotePosition()); - assertEquals(0, cache.getCurrentPosition()); - assertEquals(0, cache.getEarliestResendPosition()); - assertEquals(0, cache.size()); - - cache.sent(STREAM); - - assertEquals(0, cache.getRemotePosition()); - assertEquals(14, cache.getCurrentPosition()); - assertEquals(0, cache.getEarliestResendPosition()); - assertEquals(14, cache.size()); - - cache.updateRemotePosition(14); - - assertEquals(14, cache.getRemotePosition()); - assertEquals(14, cache.getCurrentPosition()); - assertEquals(14, cache.getEarliestResendPosition()); - assertEquals(0, cache.size()); - - cache.sent(CANCEL); - - assertEquals(14, cache.getRemotePosition()); - assertEquals(20, cache.getCurrentPosition()); - assertEquals(14, cache.getEarliestResendPosition()); - assertEquals(6, cache.size()); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/resume/ResumeTokenTest.java b/rsocket-core/src/test/java/io/rsocket/resume/ResumeTokenTest.java deleted file mode 100644 index 122fad857..000000000 --- a/rsocket-core/src/test/java/io/rsocket/resume/ResumeTokenTest.java +++ /dev/null @@ -1,21 +0,0 @@ -package io.rsocket.resume; - -import static org.junit.Assert.assertEquals; - -import java.util.UUID; -import org.junit.Test; - -public class ResumeTokenTest { - @Test - public void testFromUuid() { - UUID x = UUID.fromString("3bac9870-3873-403a-99f4-9728aa8c7860"); - - ResumeToken t = ResumeToken.bytes(ResumeToken.getBytesFromUUID(x)); - ResumeToken t2 = ResumeToken.bytes(ResumeToken.getBytesFromUUID(x)); - - assertEquals("3bac98703873403a99f49728aa8c7860", t.toString()); - - assertEquals(t.hashCode(), t2.hashCode()); - assertEquals(t, t2); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/resume/ResumeUtilTest.java b/rsocket-core/src/test/java/io/rsocket/resume/ResumeUtilTest.java deleted file mode 100644 index 68f64c1ba..000000000 --- a/rsocket-core/src/test/java/io/rsocket/resume/ResumeUtilTest.java +++ /dev/null @@ -1,44 +0,0 @@ -package io.rsocket.resume; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - -import io.rsocket.Frame; -import io.rsocket.FrameType; -import io.rsocket.util.PayloadImpl; -import org.junit.Test; - -public class ResumeUtilTest { - private Frame CANCEL = Frame.Cancel.from(1); - private Frame STREAM = - Frame.Request.from(1, FrameType.REQUEST_STREAM, new PayloadImpl("Test"), 100); - - @Test - public void testSupportedTypes() { - assertTrue(ResumeUtil.isTracked(FrameType.REQUEST_STREAM)); - assertTrue(ResumeUtil.isTracked(FrameType.REQUEST_CHANNEL)); - assertTrue(ResumeUtil.isTracked(FrameType.REQUEST_RESPONSE)); - assertTrue(ResumeUtil.isTracked(FrameType.REQUEST_N)); - assertTrue(ResumeUtil.isTracked(FrameType.CANCEL)); - assertTrue(ResumeUtil.isTracked(FrameType.ERROR)); - assertTrue(ResumeUtil.isTracked(FrameType.FIRE_AND_FORGET)); - assertTrue(ResumeUtil.isTracked(FrameType.PAYLOAD)); - } - - @Test - public void testUnsupportedTypes() { - assertFalse(ResumeUtil.isTracked(FrameType.METADATA_PUSH)); - assertFalse(ResumeUtil.isTracked(FrameType.RESUME)); - assertFalse(ResumeUtil.isTracked(FrameType.RESUME_OK)); - assertFalse(ResumeUtil.isTracked(FrameType.SETUP)); - assertFalse(ResumeUtil.isTracked(FrameType.EXT)); - assertFalse(ResumeUtil.isTracked(FrameType.KEEPALIVE)); - } - - @Test - public void testOffset() { - assertEquals(6, ResumeUtil.offset(CANCEL)); - assertEquals(14, ResumeUtil.offset(STREAM)); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/resume/ServerRSocketSessionTest.java b/rsocket-core/src/test/java/io/rsocket/resume/ServerRSocketSessionTest.java new file mode 100644 index 000000000..b5625bf8e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/resume/ServerRSocketSessionTest.java @@ -0,0 +1,190 @@ +package io.rsocket.resume; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.ReferenceCounted; +import io.rsocket.FrameAssert; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.ResumeFrameCodec; +import io.rsocket.keepalive.KeepAliveSupport; +import io.rsocket.test.util.TestClientTransport; +import java.time.Duration; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Operators; +import reactor.test.StepVerifier; +import reactor.test.scheduler.VirtualTimeScheduler; + +public class ServerRSocketSessionTest { + + @Test + void sessionTimeoutSmokeTest() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ServerRSocketSession session = + new ServerRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.testConnection(), + framesStore, + Duration.ofMinutes(1), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + // deactivate connection + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // resubscribe so a new connection is generated + transport.connect().subscribe(); + + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(50)); + // timeout should not terminate current connection + assertThat(transport.testConnection().isDisposed()).isFalse(); + + // send RESUME frame + final ByteBuf resumeFrame = + ResumeFrameCodec.encode(transport.alloc(), Unpooled.EMPTY_BUFFER, 0, 0); + session.resumeWith(resumeFrame, transport.testConnection()); + resumeFrame.release(); + + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be terminated + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME_OK) + .matches(ReferenceCounted::release); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(15)); + + // disconnects for the second time + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + transport.connect().subscribe(); + + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(61)); + + final ByteBuf resumeFrame1 = + ResumeFrameCodec.encode(transport.alloc(), Unpooled.EMPTY_BUFFER, 0, 0); + session.resumeWith(resumeFrame1, transport.testConnection()); + resumeFrame1.release(); + + // should obtain new connection + assertThat(transport.testConnection().isDisposed()).isTrue(); + // timeout should be still active since no RESUME_OK frame has been received yet + assertThat(session.s).isEqualTo(Operators.cancelledSubscription()); + assertThat(session.isDisposed()).isTrue(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.ERROR) + .matches(ReferenceCounted::release); + + resumableDuplexConnection.onClose().as(StepVerifier::create).expectComplete().verify(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } + + @Test + void shouldTerminateConnectionOnIllegalStateInKeepAliveFrame() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ServerRSocketSession session = + new ServerRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.testConnection(), + framesStore, + Duration.ofMinutes(1), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + keepAliveSupport.resumeState(session); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + final ByteBuf keepAliveFrame = + KeepAliveFrameCodec.encode(transport.alloc(), false, 1529, Unpooled.EMPTY_BUFFER); + keepAliveSupport.receive(keepAliveFrame); + keepAliveFrame.release(); + + assertThat(transport.testConnection().isDisposed()).isTrue(); + // timeout should be terminated + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isTrue(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.ERROR) + .matches(ReferenceCounted::release); + + resumableDuplexConnection.onClose().as(StepVerifier::create).expectError().verify(); + keepAliveSupport.dispose(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/ByteBufUtils.java b/rsocket-core/src/test/java/io/rsocket/test/util/ByteBufUtils.java new file mode 100644 index 000000000..9bed415ae --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/test/util/ByteBufUtils.java @@ -0,0 +1,32 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.test.util; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import java.util.concurrent.ThreadLocalRandom; + +public final class ByteBufUtils { + + private ByteBufUtils() {} + + public static ByteBuf getRandomByteBuf(int size) { + byte[] bytes = new byte[size]; + ThreadLocalRandom.current().nextBytes(bytes); + return Unpooled.wrappedBuffer(bytes); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java b/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java index 3e8a0e551..cdfcefdc8 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,58 +16,109 @@ package io.rsocket.test.util; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; -import io.rsocket.Frame; -import org.reactivestreams.Publisher; -import reactor.core.publisher.DirectProcessor; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import java.net.SocketAddress; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Scannable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; public class LocalDuplexConnection implements DuplexConnection { - private final DirectProcessor send; - private final DirectProcessor receive; - private final MonoProcessor closeNotifier; + private final ByteBufAllocator allocator; + private final Sinks.Many send; + private final Sinks.Many receive; + private final Sinks.Empty onClose; private final String name; public LocalDuplexConnection( - String name, DirectProcessor send, DirectProcessor receive) { + String name, + ByteBufAllocator allocator, + Sinks.Many send, + Sinks.Many receive) { this.name = name; + this.allocator = allocator; this.send = send; this.receive = receive; - closeNotifier = MonoProcessor.create(); + this.onClose = Sinks.empty(); } @Override - public Mono send(Publisher frame) { - return Flux.from(frame) + public void sendFrame(int streamId, ByteBuf frame) { + System.out.println(name + " - " + frame.toString()); + send.tryEmitNext(frame); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, 0, e); + System.out.println(name + " - " + errorFrame.toString()); + send.tryEmitNext(errorFrame); + onClose.tryEmitEmpty(); + } + + @Override + public Flux receive() { + return receive + .asFlux() .doOnNext(f -> System.out.println(name + " - " + f.toString())) - .doOnNext(send::onNext) - .doOnError(send::onError) - .then(); + .transform( + Operators.lift( + (__, actual) -> + new CoreSubscriber() { + + @Override + public void onSubscribe(Subscription s) { + actual.onSubscribe(s); + } + + @Override + public void onNext(ByteBuf byteBuf) { + actual.onNext(byteBuf); + byteBuf.release(); + } + + @Override + public void onError(Throwable t) { + actual.onError(t); + } + + @Override + public void onComplete() { + actual.onComplete(); + } + })); + } + + @Override + public ByteBufAllocator alloc() { + return allocator; } @Override - public Flux receive() { - return receive.doOnNext(f -> System.out.println(name + " - " + f.toString())); + public SocketAddress remoteAddress() { + return new TestLocalSocketAddress(name); } @Override - public double availability() { - return 1; + public void dispose() { + onClose.tryEmitEmpty(); } @Override - public Mono close() { - return Mono.defer( - () -> { - closeNotifier.onComplete(); - return Mono.empty(); - }); + @SuppressWarnings("ConstantConditions") + public boolean isDisposed() { + return onClose.scan(Scannable.Attr.TERMINATED) || onClose.scan(Scannable.Attr.CANCELLED); } @Override public Mono onClose() { - return closeNotifier; + return onClose.asMono(); } } diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/MockRSocket.java b/rsocket-core/src/test/java/io/rsocket/test/util/MockRSocket.java index e0ee1713c..a33c4c4b3 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/MockRSocket.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/MockRSocket.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,8 +16,7 @@ package io.rsocket.test.util; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.is; +import static org.assertj.core.api.Assertions.assertThat; import io.rsocket.Payload; import io.rsocket.RSocket; @@ -77,8 +76,13 @@ public double availability() { } @Override - public Mono close() { - return delegate.close(); + public void dispose() { + delegate.dispose(); + } + + @Override + public boolean isDisposed() { + return delegate.isDisposed(); } @Override @@ -111,6 +115,8 @@ public void assertMetadataPushCount(int expected) { } private static void assertCount(int expected, String type, AtomicInteger counter) { - assertThat("Unexpected invocations for " + type + '.', counter.get(), is(expected)); + assertThat(counter.get()) + .describedAs("Unexpected invocations for " + type + '.') + .isEqualTo(expected); } } diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/StringUtils.java b/rsocket-core/src/test/java/io/rsocket/test/util/StringUtils.java new file mode 100644 index 000000000..403eacb6d --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/test/util/StringUtils.java @@ -0,0 +1,34 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.test.util; + +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Collectors; + +public final class StringUtils { + + private static final String CANDIDATE_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"; + + private StringUtils() {} + + public static String getRandomString(int size) { + return ThreadLocalRandom.current() + .ints(size, 0, CANDIDATE_CHARS.length()) + .mapToObj(index -> ((Character) CANDIDATE_CHARS.charAt(index)).toString()) + .collect(Collectors.joining()); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java new file mode 100644 index 000000000..f02bc99a4 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java @@ -0,0 +1,43 @@ +package io.rsocket.test.util; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.transport.ClientTransport; +import java.time.Duration; +import reactor.core.publisher.Mono; + +public class TestClientTransport implements ClientTransport { + private final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "client"); + + private volatile TestDuplexConnection testDuplexConnection; + + int maxFrameLength = FRAME_LENGTH_MASK; + + @Override + public Mono connect() { + return Mono.fromSupplier(() -> testDuplexConnection = new TestDuplexConnection(allocator)); + } + + public TestDuplexConnection testConnection() { + return testDuplexConnection; + } + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } + + public TestClientTransport withMaxFrameLength(int maxFrameLength) { + this.maxFrameLength = maxFrameLength; + return this; + } + + @Override + public int maxFrameLength() { + return maxFrameLength; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java index 357ffda69..8793d6ca4 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,19 +16,26 @@ package io.rsocket.test.util; +import io.netty.buffer.ByteBuf; import io.rsocket.DuplexConnection; -import io.rsocket.Frame; -import java.util.Collection; -import java.util.concurrent.ConcurrentLinkedQueue; +import io.rsocket.RSocketErrorException; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.ErrorFrameCodec; +import java.net.SocketAddress; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; import reactor.core.publisher.DirectProcessor; import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; /** * An implementation of {@link DuplexConnection} that provides functionality to modify the behavior @@ -38,44 +45,88 @@ public class TestDuplexConnection implements DuplexConnection { private static final Logger logger = LoggerFactory.getLogger(TestDuplexConnection.class); - private final LinkedBlockingQueue sent; - private final DirectProcessor sentPublisher; - private final DirectProcessor received; - private final MonoProcessor close; - private final ConcurrentLinkedQueue> sendSubscribers; + private final LinkedBlockingQueue sent; + + private final DirectProcessor sentPublisher; + private final FluxSink sendSink; + private final DirectProcessor received; + private final FluxSink receivedSink; + private final MonoProcessor onClose; + private final LeaksTrackingByteBufAllocator allocator; private volatile double availability = 1; private volatile int initialSendRequestN = Integer.MAX_VALUE; - public TestDuplexConnection() { - sent = new LinkedBlockingQueue<>(); - received = DirectProcessor.create(); - sentPublisher = DirectProcessor.create(); - sendSubscribers = new ConcurrentLinkedQueue<>(); - close = MonoProcessor.create(); + public TestDuplexConnection(LeaksTrackingByteBufAllocator allocator) { + this.allocator = allocator; + this.sent = new LinkedBlockingQueue<>(); + this.received = DirectProcessor.create(); + this.receivedSink = received.sink(); + this.sentPublisher = DirectProcessor.create(); + this.sendSink = sentPublisher.sink(); + this.onClose = MonoProcessor.create(); } @Override - public Mono send(Publisher frames) { + public void sendFrame(int streamId, ByteBuf frame) { if (availability <= 0) { - return Mono.error( - new IllegalStateException("RSocket not available. Availability: " + availability)); + throw new IllegalStateException("RSocket not available. Availability: " + availability); } - Subscriber subscriber = TestSubscriber.create(initialSendRequestN); - Flux.from(frames) - .doOnNext( - frame -> { - sent.offer(frame); - sentPublisher.onNext(frame); - }) - .doOnError(throwable -> logger.error("Error in send stream on test connection.", throwable)) - .subscribe(subscriber); - sendSubscribers.add(subscriber); - return Mono.empty(); + + sendSink.next(frame); + sent.offer(frame); + } + + @Override + public Flux receive() { + return received.transform( + Operators.lift( + (__, actual) -> + new CoreSubscriber() { + @Override + public void onSubscribe(Subscription s) { + actual.onSubscribe(s); + } + + @Override + public void onNext(ByteBuf byteBuf) { + actual.onNext(byteBuf); + byteBuf.release(); + } + + @Override + public void onError(Throwable t) { + actual.onError(t); + } + + @Override + public void onComplete() { + actual.onComplete(); + } + })); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, 0, e); + sendSink.next(errorFrame); + sent.offer(errorFrame); + + final Throwable cause = e.getCause(); + if (cause == null) { + onClose.onComplete(); + } else { + onClose.onError(cause); + } + } + + @Override + public LeaksTrackingByteBufAllocator alloc() { + return allocator; } @Override - public Flux receive() { - return received; + public SocketAddress remoteAddress() { + return new TestLocalSocketAddress("TestDuplexConnection"); } @Override @@ -84,47 +135,60 @@ public double availability() { } @Override - public Mono close() { - return close; + public void dispose() { + onClose.onComplete(); + } + + @Override + public boolean isDisposed() { + return onClose.isDisposed(); } @Override public Mono onClose() { - return close(); + return onClose; } - public Frame awaitSend() throws InterruptedException { - return sent.take(); + public boolean isEmpty() { + return sent.isEmpty(); + } + + @NonNull + public ByteBuf awaitFrame() { + try { + return sent.take(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + public ByteBuf pollFrame() { + return sent.poll(); } public void setAvailability(double availability) { this.availability = availability; } - public Collection getSent() { + public BlockingQueue getSent() { return sent; } - public Publisher getSentAsPublisher() { + public Publisher getSentAsPublisher() { return sentPublisher; } - public void addToReceivedBuffer(Frame... received) { - for (Frame frame : received) { - this.received.onNext(frame); + public void addToReceivedBuffer(ByteBuf... received) { + for (ByteBuf frame : received) { + this.receivedSink.next(frame); } } public void clearSendReceiveBuffers() { sent.clear(); - sendSubscribers.clear(); } public void setInitialSendRequestN(int initialSendRequestN) { this.initialSendRequestN = initialSendRequestN; } - - public Collection> getSendSubscribers() { - return sendSubscribers; - } } diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestLocalSocketAddress.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestLocalSocketAddress.java new file mode 100644 index 000000000..2dad2cc1f --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestLocalSocketAddress.java @@ -0,0 +1,46 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.test.util; + +import java.net.SocketAddress; +import java.util.Objects; + +public final class TestLocalSocketAddress extends SocketAddress { + + private static final long serialVersionUID = 2608695156052100164L; + + private final String name; + + /** + * Creates a new instance. + * + * @param name the name representing the address + * @throws NullPointerException if {@code name} is {@code null} + */ + public TestLocalSocketAddress(String name) { + this.name = Objects.requireNonNull(name, "name must not be null"); + } + + /** Return the name for this connection. */ + public String getName() { + return name; + } + + @Override + public String toString() { + return "[local address] " + name; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java new file mode 100644 index 000000000..fa9331d3b --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java @@ -0,0 +1,90 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.test.util; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Closeable; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.transport.ServerTransport; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +public class TestServerTransport implements ServerTransport { + private final Sinks.One connSink = Sinks.one(); + private TestDuplexConnection connection; + private final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + int maxFrameLength = FRAME_LENGTH_MASK; + + @Override + public Mono start(ConnectionAcceptor acceptor) { + connSink + .asMono() + .flatMap(duplexConnection -> acceptor.apply(duplexConnection)) + .subscribe(ignored -> {}, err -> disposeConnection(), this::disposeConnection); + return Mono.just( + new Closeable() { + @Override + public Mono onClose() { + return connSink.asMono().then(); + } + + @Override + public void dispose() { + connSink.tryEmitEmpty(); + } + + @Override + @SuppressWarnings("ConstantConditions") + public boolean isDisposed() { + return connSink.scan(Scannable.Attr.TERMINATED) + || connSink.scan(Scannable.Attr.CANCELLED); + } + }); + } + + private void disposeConnection() { + TestDuplexConnection c = connection; + if (c != null) { + c.dispose(); + } + } + + public TestDuplexConnection connect() { + TestDuplexConnection c = new TestDuplexConnection(allocator); + connection = c; + connSink.tryEmitValue(c); + return c; + } + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } + + public TestServerTransport withMaxFrameLength(int maxFrameLength) { + this.maxFrameLength = maxFrameLength; + return this; + } + + @Override + public int maxFrameLength() { + return maxFrameLength; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestSubscriber.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestSubscriber.java index cc419a68a..e88b39648 100644 --- a/rsocket-core/src/test/java/io/rsocket/test/util/TestSubscriber.java +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestSubscriber.java @@ -1,3 +1,19 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.rsocket.test.util; import static org.mockito.ArgumentMatchers.any; diff --git a/rsocket-core/src/test/java/io/rsocket/uri/TestUriHandler.java b/rsocket-core/src/test/java/io/rsocket/uri/TestUriHandler.java deleted file mode 100644 index ba1dff548..000000000 --- a/rsocket-core/src/test/java/io/rsocket/uri/TestUriHandler.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.uri; - -import io.rsocket.test.util.TestDuplexConnection; -import io.rsocket.transport.ClientTransport; -import java.net.URI; -import java.util.Optional; -import reactor.core.publisher.Mono; - -public class TestUriHandler implements UriHandler { - @Override - public Optional buildClient(URI uri) { - if ("test".equals(uri.getScheme())) { - return Optional.of(() -> Mono.just(new TestDuplexConnection())); - } - return UriHandler.super.buildClient(uri); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/uri/UriTransportRegistryTest.java b/rsocket-core/src/test/java/io/rsocket/uri/UriTransportRegistryTest.java deleted file mode 100644 index 5b90d64f0..000000000 --- a/rsocket-core/src/test/java/io/rsocket/uri/UriTransportRegistryTest.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.uri; - -import static org.junit.Assert.assertTrue; - -import io.rsocket.DuplexConnection; -import io.rsocket.test.util.TestDuplexConnection; -import io.rsocket.transport.ClientTransport; -import org.junit.Test; - -public class UriTransportRegistryTest { - @Test - public void testTestRegistered() { - ClientTransport test = UriTransportRegistry.clientForUri("test://test"); - - DuplexConnection duplexConnection = test.connect().block(); - - assertTrue(duplexConnection instanceof TestDuplexConnection); - } - - @Test(expected = UnsupportedOperationException.class) - public void testTestUnregistered() { - ClientTransport test = UriTransportRegistry.clientForUri("mailto://bonson@baulsupp.net"); - - test.connect().block(); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/util/ByteBufPayloadTest.java b/rsocket-core/src/test/java/io/rsocket/util/ByteBufPayloadTest.java new file mode 100644 index 000000000..2ad944d09 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/util/ByteBufPayloadTest.java @@ -0,0 +1,64 @@ +package io.rsocket.util; + +import io.netty.buffer.Unpooled; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +public class ByteBufPayloadTest { + + @Test + public void shouldIndicateThatItHasMetadata() { + Payload payload = ByteBufPayload.create("data", "metadata"); + + Assertions.assertThat(payload.hasMetadata()).isTrue(); + Assertions.assertThat(payload.release()).isTrue(); + } + + @Test + public void shouldIndicateThatItHasNotMetadata() { + Payload payload = ByteBufPayload.create("data"); + + Assertions.assertThat(payload.hasMetadata()).isFalse(); + Assertions.assertThat(payload.release()).isTrue(); + } + + @Test + public void shouldIndicateThatItHasMetadata1() { + Payload payload = + ByteBufPayload.create(Unpooled.wrappedBuffer("data".getBytes()), Unpooled.EMPTY_BUFFER); + + Assertions.assertThat(payload.hasMetadata()).isTrue(); + Assertions.assertThat(payload.release()).isTrue(); + } + + @Test + public void shouldThrowExceptionIfAccessAfterRelease() { + Payload payload = ByteBufPayload.create("data", "metadata"); + + Assertions.assertThat(payload.release()).isTrue(); + + Assertions.assertThatThrownBy(payload::hasMetadata) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::data).isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::metadata) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::sliceData) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::sliceMetadata) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::touch) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(() -> payload.touch("test")) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::getData) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::getMetadata) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::getDataUtf8) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::getMetadataUtf8) + .isInstanceOf(IllegalReferenceCountException.class); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java b/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java new file mode 100644 index 000000000..f04de78b6 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java @@ -0,0 +1,107 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.util; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import java.nio.ByteBuffer; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.jupiter.api.Test; + +public class DefaultPayloadTest { + public static final String DATA_VAL = "data"; + public static final String METADATA_VAL = "metadata"; + + @Test + public void testReuse() { + Payload p = DefaultPayload.create(DATA_VAL, METADATA_VAL); + assertDataAndMetadata(p, DATA_VAL, METADATA_VAL); + assertDataAndMetadata(p, DATA_VAL, METADATA_VAL); + } + + public void assertDataAndMetadata(Payload p, String dataVal, String metadataVal) { + assertThat(p.getDataUtf8()).describedAs("Unexpected data.").isEqualTo(dataVal); + if (metadataVal == null) { + assertThat(p.hasMetadata()).describedAs("Non-null metadata").isEqualTo(false); + } else { + assertThat(p.hasMetadata()).describedAs("Null metadata").isEqualTo(true); + assertThat(p.getMetadataUtf8()).describedAs("Unexpected metadata.").isEqualTo(metadataVal); + } + } + + @Test + public void staticMethods() { + assertDataAndMetadata(DefaultPayload.create(DATA_VAL, METADATA_VAL), DATA_VAL, METADATA_VAL); + assertDataAndMetadata(DefaultPayload.create(DATA_VAL), DATA_VAL, null); + } + + @Test + public void shouldIndicateThatItHasNotMetadata() { + Payload payload = DefaultPayload.create("data"); + + assertThat(payload.hasMetadata()).isFalse(); + } + + @Test + public void shouldIndicateThatItHasMetadata1() { + Payload payload = + DefaultPayload.create(Unpooled.wrappedBuffer("data".getBytes()), Unpooled.EMPTY_BUFFER); + + assertThat(payload.hasMetadata()).isTrue(); + } + + @Test + public void shouldIndicateThatItHasMetadata2() { + Payload payload = + DefaultPayload.create(ByteBuffer.wrap("data".getBytes()), ByteBuffer.allocate(0)); + + assertThat(payload.hasMetadata()).isTrue(); + } + + @Test + public void shouldReleaseGivenByteBufDataAndMetadataUpOnPayloadCreation() { + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + for (byte i = 0; i < 126; i++) { + ByteBuf data = allocator.buffer(); + data.writeByte(i); + + boolean metadataPresent = ThreadLocalRandom.current().nextBoolean(); + ByteBuf metadata = null; + if (metadataPresent) { + metadata = allocator.buffer(); + metadata.writeByte(i + 1); + } + + Payload payload = DefaultPayload.create(data, metadata); + + assertThat(payload.getData()).isEqualTo(ByteBuffer.wrap(new byte[] {i})); + + assertThat(payload.getMetadata()) + .isEqualTo( + metadataPresent + ? ByteBuffer.wrap(new byte[] {(byte) (i + 1)}) + : DefaultPayload.EMPTY_BUFFER); + allocator.assertHasNoLeaks(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/util/ExceptionUtilTest.java b/rsocket-core/src/test/java/io/rsocket/util/ExceptionUtilTest.java deleted file mode 100644 index 06eea8975..000000000 --- a/rsocket-core/src/test/java/io/rsocket/util/ExceptionUtilTest.java +++ /dev/null @@ -1,26 +0,0 @@ -package io.rsocket.util; - -import static io.rsocket.util.ExceptionUtil.noStacktrace; -import static org.junit.Assert.assertEquals; - -import java.io.PrintWriter; -import java.io.StringWriter; -import org.junit.Test; - -public class ExceptionUtilTest { - @Test - public void testNoStacktrace() { - RuntimeException ex = noStacktrace(new RuntimeException("RE")); - assertEquals( - String.format( - "java.lang.RuntimeException: RE%n" - + "\tat java.lang.RuntimeException.(Unknown Source)%n"), - stacktraceString(ex)); - } - - private String stacktraceString(RuntimeException ex) { - StringWriter stringWriter = new StringWriter(); - ex.printStackTrace(new PrintWriter(stringWriter)); - return stringWriter.toString(); - } -} diff --git a/rsocket-core/src/test/java/io/rsocket/util/NumberUtilsTest.java b/rsocket-core/src/test/java/io/rsocket/util/NumberUtilsTest.java new file mode 100644 index 000000000..46e0f77f4 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/util/NumberUtilsTest.java @@ -0,0 +1,187 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.util; + +import static org.assertj.core.api.Assertions.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +final class NumberUtilsTest { + + @DisplayName("returns int value with postitive int") + @Test + void requireNonNegativeInt() { + assertThat(NumberUtils.requireNonNegative(Integer.MAX_VALUE, "test-message")) + .isEqualTo(Integer.MAX_VALUE); + } + + @DisplayName( + "requireNonNegative with int argument throws IllegalArgumentException with negative value") + @Test + void requireNonNegativeIntNegative() { + assertThatIllegalArgumentException() + .isThrownBy(() -> NumberUtils.requireNonNegative(Integer.MIN_VALUE, "test-message")) + .withMessage("test-message"); + } + + @DisplayName("requireNonNegative with int argument throws NullPointerException with null message") + @Test + void requireNonNegativeIntNullMessage() { + assertThatNullPointerException() + .isThrownBy(() -> NumberUtils.requireNonNegative(Integer.MIN_VALUE, null)) + .withMessage("message must not be null"); + } + + @DisplayName("requireNonNegative returns int value with zero") + @Test + void requireNonNegativeIntZero() { + assertThat(NumberUtils.requireNonNegative(0, "test-message")).isEqualTo(0); + } + + @DisplayName("requirePositive returns int value with positive int") + @Test + void requirePositiveInt() { + assertThat(NumberUtils.requirePositive(Integer.MAX_VALUE, "test-message")) + .isEqualTo(Integer.MAX_VALUE); + } + + @DisplayName( + "requirePositive with int argument throws IllegalArgumentException with negative value") + @Test + void requirePositiveIntNegative() { + assertThatIllegalArgumentException() + .isThrownBy(() -> NumberUtils.requirePositive(Integer.MIN_VALUE, "test-message")) + .withMessage("test-message"); + } + + @DisplayName("requirePositive with int argument throws NullPointerException with null message") + @Test + void requirePositiveIntNullMessage() { + assertThatNullPointerException() + .isThrownBy(() -> NumberUtils.requirePositive(Integer.MIN_VALUE, null)) + .withMessage("message must not be null"); + } + + @DisplayName("requirePositive with int argument throws IllegalArgumentException with zero value") + @Test + void requirePositiveIntZero() { + assertThatIllegalArgumentException() + .isThrownBy(() -> NumberUtils.requirePositive(0, "test-message")) + .withMessage("test-message"); + } + + @DisplayName("requirePositive returns long value with positive long") + @Test + void requirePositiveLong() { + assertThat(NumberUtils.requirePositive(Long.MAX_VALUE, "test-message")) + .isEqualTo(Long.MAX_VALUE); + } + + @DisplayName( + "requirePositive with long argument throws IllegalArgumentException with negative value") + @Test + void requirePositiveLongNegative() { + assertThatIllegalArgumentException() + .isThrownBy(() -> NumberUtils.requirePositive(Long.MIN_VALUE, "test-message")) + .withMessage("test-message"); + } + + @DisplayName("requirePositive with long argument throws NullPointerException with null message") + @Test + void requirePositiveLongNullMessage() { + assertThatNullPointerException() + .isThrownBy(() -> NumberUtils.requirePositive(Long.MIN_VALUE, null)) + .withMessage("message must not be null"); + } + + @DisplayName("requirePositive with long argument throws IllegalArgumentException with zero value") + @Test + void requirePositiveLongZero() { + assertThatIllegalArgumentException() + .isThrownBy(() -> NumberUtils.requirePositive(0L, "test-message")) + .withMessage("test-message"); + } + + @DisplayName("requireUnsignedByte returns length if 255") + @Test + void requireUnsignedByte() { + assertThat(NumberUtils.requireUnsignedByte((1 << 8) - 1)).isEqualTo(255); + } + + @DisplayName("requireUnsignedByte throws IllegalArgumentException if larger than 255") + @Test + void requireUnsignedByteOverFlow() { + assertThatIllegalArgumentException() + .isThrownBy(() -> NumberUtils.requireUnsignedByte(1 << 8)) + .withMessage("%d is larger than 8 bits", 1 << 8); + } + + @DisplayName("requireUnsignedMedium returns length if 16_777_215") + @Test + void requireUnsignedMedium() { + assertThat(NumberUtils.requireUnsignedMedium((1 << 24) - 1)).isEqualTo(16_777_215); + } + + @DisplayName("requireUnsignedMedium throws IllegalArgumentException if larger than 16_777_215") + @Test + void requireUnsignedMediumOverFlow() { + assertThatIllegalArgumentException() + .isThrownBy(() -> NumberUtils.requireUnsignedMedium(1 << 24)) + .withMessage("%d is larger than 24 bits", 1 << 24); + } + + @DisplayName("requireUnsignedShort returns length if 65_535") + @Test + void requireUnsignedShort() { + assertThat(NumberUtils.requireUnsignedShort((1 << 16) - 1)).isEqualTo(65_535); + } + + @DisplayName("requireUnsignedShort throws IllegalArgumentException if larger than 65_535") + @Test + void requireUnsignedShortOverFlow() { + assertThatIllegalArgumentException() + .isThrownBy(() -> NumberUtils.requireUnsignedShort(1 << 16)) + .withMessage("%d is larger than 16 bits", 1 << 16); + } + + @Test + void encodeUnsignedMedium() { + ByteBuf buffer = ByteBufAllocator.DEFAULT.buffer(); + NumberUtils.encodeUnsignedMedium(buffer, 129); + buffer.markReaderIndex(); + + assertThat(buffer.readUnsignedMedium()).as("reading as unsigned medium").isEqualTo(129); + + buffer.resetReaderIndex(); + assertThat(buffer.readMedium()).as("reading as signed medium").isEqualTo(129); + } + + @Test + void encodeUnsignedMediumLarge() { + ByteBuf buffer = ByteBufAllocator.DEFAULT.buffer(); + NumberUtils.encodeUnsignedMedium(buffer, 0xFFFFFC); + buffer.markReaderIndex(); + + assertThat(buffer.readUnsignedMedium()).as("reading as unsigned medium").isEqualTo(16777212); + + buffer.resetReaderIndex(); + assertThat(buffer.readMedium()).as("reading as signed medium").isEqualTo(-4); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/util/PayloadImplTest.java b/rsocket-core/src/test/java/io/rsocket/util/PayloadImplTest.java deleted file mode 100644 index 9099efa9a..000000000 --- a/rsocket-core/src/test/java/io/rsocket/util/PayloadImplTest.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - *

- * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - *

- * http://www.apache.org/licenses/LICENSE-2.0 - *

- * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on - * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package io.rsocket.util; - -import static io.rsocket.util.PayloadImpl.textPayload; -import static org.hamcrest.MatcherAssert.*; -import static org.hamcrest.Matchers.*; - -import io.rsocket.Payload; -import javax.annotation.Nullable; -import org.junit.Test; - -public class PayloadImplTest { - public static final String DATA_VAL = "data"; - public static final String METADATA_VAL = "metadata"; - - @Test - public void testReuse() { - PayloadImpl p = new PayloadImpl(DATA_VAL, METADATA_VAL); - assertDataAndMetadata(p, DATA_VAL, METADATA_VAL); - assertDataAndMetadata(p, DATA_VAL, METADATA_VAL); - } - - @Test - public void testReuseWithExternalMark() { - PayloadImpl p = new PayloadImpl(DATA_VAL, METADATA_VAL); - assertDataAndMetadata(p, DATA_VAL, METADATA_VAL); - p.getData().position(2).mark(); - assertDataAndMetadata(p, DATA_VAL, METADATA_VAL); - } - - public void assertDataAndMetadata(Payload p, String dataVal, @Nullable String metadataVal) { - assertThat("Unexpected data.", p.getDataUtf8(), equalTo(dataVal)); - if (metadataVal == null) { - assertThat("Non-null metadata", p.hasMetadata(), equalTo(false)); - } else { - assertThat("Null metadata", p.hasMetadata(), equalTo(true)); - assertThat("Unexpected metadata.", p.getMetadataUtf8(), equalTo(metadataVal)); - } - } - - @Test - public void staticMethods() { - assertDataAndMetadata(textPayload(DATA_VAL, METADATA_VAL), DATA_VAL, METADATA_VAL); - assertDataAndMetadata(textPayload(DATA_VAL), DATA_VAL, null); - } -} diff --git a/rsocket-core/src/test/resources/META-INF/services/io.rsocket.uri.UriHandler b/rsocket-core/src/test/resources/META-INF/services/io.rsocket.uri.UriHandler deleted file mode 100644 index 4a10f7ed6..000000000 --- a/rsocket-core/src/test/resources/META-INF/services/io.rsocket.uri.UriHandler +++ /dev/null @@ -1,17 +0,0 @@ -# -# Copyright 2016 Netflix, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -io.rsocket.uri.TestUriHandler diff --git a/rsocket-examples/src/main/resources/log4j.properties b/rsocket-core/src/test/resources/META-INF/services/org.assertj.core.presentation.Representation similarity index 61% rename from rsocket-examples/src/main/resources/log4j.properties rename to rsocket-core/src/test/resources/META-INF/services/org.assertj.core.presentation.Representation index 6bd4c8540..9ac418a0c 100644 --- a/rsocket-examples/src/main/resources/log4j.properties +++ b/rsocket-core/src/test/resources/META-INF/services/org.assertj.core.presentation.Representation @@ -1,11 +1,11 @@ # -# Copyright 2016 Netflix, Inc. +# Copyright 2015-2018 the original author or authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -13,8 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -log4j.rootLogger=INFO, stdout - -log4j.appender.stdout=org.apache.log4j.ConsoleAppender -log4j.appender.stdout.layout=org.apache.log4j.PatternLayout -log4j.appender.stdout.layout.ConversionPattern=%d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] - %m%n \ No newline at end of file +io.rsocket.frame.ByteBufRepresentation \ No newline at end of file diff --git a/rsocket-core/src/test/resources/META-INF/services/org.junit.jupiter.api.extension.Extension b/rsocket-core/src/test/resources/META-INF/services/org.junit.jupiter.api.extension.Extension new file mode 100644 index 000000000..2b51ba0de --- /dev/null +++ b/rsocket-core/src/test/resources/META-INF/services/org.junit.jupiter.api.extension.Extension @@ -0,0 +1 @@ +io.rsocket.frame.ByteBufRepresentation \ No newline at end of file diff --git a/rsocket-core/src/test/resources/log4j.properties b/rsocket-core/src/test/resources/log4j.properties deleted file mode 100644 index 0ded7169b..000000000 --- a/rsocket-core/src/test/resources/log4j.properties +++ /dev/null @@ -1,34 +0,0 @@ -# -# Copyright 2016 Netflix, Inc. -#

-# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -#

-# http://www.apache.org/licenses/LICENSE-2.0 -#

-# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on -# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. -# - - -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -log4j.rootLogger=INFO, stdout - -log4j.appender.stdout=org.apache.log4j.ConsoleAppender -log4j.appender.stdout.layout=org.apache.log4j.PatternLayout -log4j.appender.stdout.layout.ConversionPattern=%d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] (%F:%L) - %m%n -#log4j.logger.io.rsocket.FrameLogger=Info \ No newline at end of file diff --git a/rsocket-core/src/test/resources/logback-test.xml b/rsocket-core/src/test/resources/logback-test.xml new file mode 100644 index 000000000..9081698fb --- /dev/null +++ b/rsocket-core/src/test/resources/logback-test.xml @@ -0,0 +1,32 @@ + + + + + + + + %date{HH:mm:ss.SSS} %-10thread %-42logger %msg%n + + + + + + + + + + diff --git a/rsocket-examples/build.gradle b/rsocket-examples/build.gradle index 4e2178662..4059eb957 100644 --- a/rsocket-examples/build.gradle +++ b/rsocket-examples/build.gradle @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -14,11 +14,37 @@ * limitations under the License. */ +plugins { + id 'java' +} + dependencies { - compile project(':rsocket-core') - compile project(':rsocket-spectator') - compile project(':rsocket-transport-netty') - compile project(':rsocket-transport-local') + implementation project(':rsocket-core') + implementation project(':rsocket-load-balancer') + implementation project(':rsocket-transport-local') + implementation project(':rsocket-transport-netty') + + implementation "io.micrometer:micrometer-core" + implementation "io.micrometer:micrometer-tracing" + implementation project(":rsocket-micrometer") + + implementation 'com.netflix.concurrency-limits:concurrency-limits-core' + implementation "io.micrometer:micrometer-core" + implementation "io.micrometer:micrometer-tracing" + implementation project(":rsocket-micrometer") - testCompile project(':rsocket-test') + runtimeOnly 'ch.qos.logback:logback-classic' + + testImplementation project(':rsocket-test') + testImplementation 'org.junit.jupiter:junit-jupiter-api' + testImplementation 'org.mockito:mockito-core' + testImplementation 'org.assertj:assertj-core' + testImplementation 'io.projectreactor:reactor-test' + testImplementation 'org.awaitility:awaitility' + testImplementation "io.micrometer:micrometer-test" + testImplementation "io.micrometer:micrometer-tracing-integration-test" + + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' } + +description = 'Example usage of the RSocket library' diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java index 4ee9d33c2..463043020 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java @@ -1,69 +1,61 @@ /* - * Copyright 2016 Netflix, Inc. - *

- * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - *

- * http://www.apache.org/licenses/LICENSE-2.0 - *

- * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on - * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package io.rsocket.examples.transport.tcp.channel; -import io.rsocket.AbstractRSocket; -import io.rsocket.ConnectionSetupPayload; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.TcpServerTransport; -import io.rsocket.util.PayloadImpl; +import io.rsocket.util.DefaultPayload; import java.time.Duration; -import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; public final class ChannelEchoClient { + private static final Logger logger = LoggerFactory.getLogger(ChannelEchoClient.class); + public static void main(String[] args) { - RSocketFactory.receive() - .acceptor(new SocketAcceptorImpl()) - .transport(TcpServerTransport.create("localhost", 7000)) - .start() - .subscribe(); + + SocketAcceptor echoAcceptor = + SocketAcceptor.forRequestChannel( + payloads -> + Flux.from(payloads) + .map(Payload::getDataUtf8) + .map(s -> "Echo: " + s) + .map(DefaultPayload::create)); + + RSocketServer.create(echoAcceptor).bindNow(TcpServerTransport.create("localhost", 7000)); RSocket socket = - RSocketFactory.connect() - .transport(TcpClientTransport.create("localhost", 7000)) - .start() - .block(); + RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000)).block(); socket - .requestChannel(Flux.interval(Duration.ofMillis(1000)).map(i -> new PayloadImpl("Hello"))) + .requestChannel( + Flux.interval(Duration.ofMillis(1000)).map(i -> DefaultPayload.create("Hello"))) .map(Payload::getDataUtf8) - .doOnNext(System.out::println) + .doOnNext(logger::debug) .take(10) - .thenEmpty(socket.close()) + .doFinally(signalType -> socket.dispose()) + .then() .block(); } - - private static class SocketAcceptorImpl implements SocketAcceptor { - @Override - public Mono accept(ConnectionSetupPayload setupPayload, RSocket reactiveSocket) { - return Mono.just( - new AbstractRSocket() { - @Override - public Flux requestChannel(Publisher payloads) { - return Flux.from(payloads) - .map(Payload::getDataUtf8) - .map(s -> "Echo: " + s) - .map(PayloadImpl::new); - } - }); - } - } } diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/client/RSocketClientExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/client/RSocketClientExample.java new file mode 100644 index 000000000..dfbbcde53 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/client/RSocketClientExample.java @@ -0,0 +1,55 @@ +package io.rsocket.examples.transport.tcp.client; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketClient; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; + +public class RSocketClientExample { + static final Logger logger = LoggerFactory.getLogger(RSocketClientExample.class); + + public static void main(String[] args) { + + RSocketServer.create( + SocketAcceptor.forRequestResponse( + p -> { + String data = p.getDataUtf8(); + logger.info("Received request data {}", data); + + Payload responsePayload = DefaultPayload.create("Echo: " + data); + p.release(); + + return Mono.just(responsePayload); + })) + .bind(TcpServerTransport.create("localhost", 7000)) + .delaySubscription(Duration.ofSeconds(5)) + .doOnNext(cc -> logger.info("Server started on the address : {}", cc.address())) + .block(); + + Mono source = + RSocketConnector.create() + .reconnect(Retry.backoff(50, Duration.ofMillis(500))) + .connect(TcpClientTransport.create("localhost", 7000)); + + RSocketClient.from(source) + .requestResponse(Mono.just(DefaultPayload.create("Test Request"))) + .doOnSubscribe(s -> logger.info("Executing Request")) + .doOnNext( + d -> { + logger.info("Received response data {}", d.getDataUtf8()); + d.release(); + }) + .repeat(10) + .blockLast(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/duplex/DuplexClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/duplex/DuplexClient.java deleted file mode 100644 index 59ceb6e35..000000000 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/duplex/DuplexClient.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright 2017 Netflix, Inc. - *

- * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - *

- * http://www.apache.org/licenses/LICENSE-2.0 - *

- * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on - * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package io.rsocket.examples.transport.tcp.duplex; - -import io.rsocket.AbstractRSocket; -import io.rsocket.Payload; -import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; -import io.rsocket.transport.netty.client.TcpClientTransport; -import io.rsocket.transport.netty.server.TcpServerTransport; -import io.rsocket.util.PayloadImpl; -import java.time.Duration; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -public final class DuplexClient { - - public static void main(String[] args) { - RSocketFactory.receive() - .acceptor( - (setup, reactiveSocket) -> { - reactiveSocket - .requestStream(new PayloadImpl("Hello-Bidi")) - .map(Payload::getDataUtf8) - .log() - .subscribe(); - - return Mono.just(new AbstractRSocket() {}); - }) - .transport(TcpServerTransport.create("localhost", 7000)) - .start() - .subscribe(); - - RSocket socket = - RSocketFactory.connect() - .acceptor( - rSocket -> - new AbstractRSocket() { - @Override - public Flux requestStream(Payload payload) { - return Flux.interval(Duration.ofSeconds(1)) - .map(aLong -> new PayloadImpl("Bi-di Response => " + aLong)); - } - }) - .transport(TcpClientTransport.create("localhost", 7000)) - .start() - .block(); - - socket.onClose().block(); - } -} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/fnf/TaskProcessingWithServerSideNotificationsExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/fnf/TaskProcessingWithServerSideNotificationsExample.java new file mode 100644 index 000000000..89b22749f --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/fnf/TaskProcessingWithServerSideNotificationsExample.java @@ -0,0 +1,237 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.examples.transport.tcp.fnf; + +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadLocalRandom; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.util.concurrent.Queues; + +/** + * An example of long-running tasks processing (a.k.a Kafka style) where a client submits tasks over + * request `FireAndForget` and then receives results over the same method but on it is own side. + * + *

This example shows a case when the client may disappear, however, another a client can connect + * again and receive undelivered completed tasks remaining for the previous one. + */ +public class TaskProcessingWithServerSideNotificationsExample { + + public static void main(String[] args) throws InterruptedException { + Sinks.Many tasksProcessor = + Sinks.many().unicast().onBackpressureBuffer(Queues.unboundedMultiproducer().get()); + ConcurrentMap> idToCompletedTasksMap = new ConcurrentHashMap<>(); + ConcurrentMap idToRSocketMap = new ConcurrentHashMap<>(); + BackgroundWorker backgroundWorker = + new BackgroundWorker(tasksProcessor.asFlux(), idToCompletedTasksMap, idToRSocketMap); + + RSocketServer.create(new TasksAcceptor(tasksProcessor, idToCompletedTasksMap, idToRSocketMap)) + .bindNow(TcpServerTransport.create(9991)); + + Logger logger = LoggerFactory.getLogger("RSocket.Client.ID[Test]"); + + Mono rSocketMono = + RSocketConnector.create() + .setupPayload(DefaultPayload.create("Test")) + .acceptor( + SocketAcceptor.forFireAndForget( + p -> { + logger.info("Received Processed Task[{}]", p.getDataUtf8()); + p.release(); + return Mono.empty(); + })) + .connect(TcpClientTransport.create(9991)); + + RSocket rSocketRequester1 = rSocketMono.block(); + + for (int i = 0; i < 10; i++) { + rSocketRequester1.fireAndForget(DefaultPayload.create("task" + i)).block(); + } + + Thread.sleep(4000); + + rSocketRequester1.dispose(); + logger.info("Disposed"); + + Thread.sleep(4000); + + RSocket rSocketRequester2 = rSocketMono.block(); + + logger.info("Reconnected"); + + Thread.sleep(10000); + } + + static class BackgroundWorker extends BaseSubscriber { + final ConcurrentMap> idToCompletedTasksMap; + final ConcurrentMap idToRSocketMap; + + BackgroundWorker( + Flux taskProducer, + ConcurrentMap> idToCompletedTasksMap, + ConcurrentMap idToRSocketMap) { + + this.idToCompletedTasksMap = idToCompletedTasksMap; + this.idToRSocketMap = idToRSocketMap; + + // mimic a long running task processing + taskProducer + .concatMap( + t -> + Mono.delay(Duration.ofMillis(ThreadLocalRandom.current().nextInt(200, 2000))) + .thenReturn(t)) + .subscribe(this); + } + + @Override + protected void hookOnNext(Task task) { + BlockingQueue completedTasksQueue = + idToCompletedTasksMap.computeIfAbsent(task.id, __ -> new LinkedBlockingQueue<>()); + + completedTasksQueue.offer(task); + RSocket rSocket = idToRSocketMap.get(task.id); + if (rSocket != null) { + rSocket + .fireAndForget(DefaultPayload.create(task.content)) + .subscribe(null, e -> {}, () -> completedTasksQueue.remove(task)); + } + } + } + + static class TasksAcceptor implements SocketAcceptor { + + static final Logger logger = LoggerFactory.getLogger(TasksAcceptor.class); + + final Sinks.Many tasksToProcess; + final ConcurrentMap> idToCompletedTasksMap; + final ConcurrentMap idToRSocketMap; + + TasksAcceptor( + Sinks.Many tasksToProcess, + ConcurrentMap> idToCompletedTasksMap, + ConcurrentMap idToRSocketMap) { + this.tasksToProcess = tasksToProcess; + this.idToCompletedTasksMap = idToCompletedTasksMap; + this.idToRSocketMap = idToRSocketMap; + } + + @Override + public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) { + String id = setup.getDataUtf8(); + logger.info("Accepting a new client connection with ID {}", id); + // sendingRSocket represents here an RSocket requester to a remote peer + + if (this.idToRSocketMap.compute( + id, (__, old) -> old == null || old.isDisposed() ? sendingSocket : old) + == sendingSocket) { + return Mono.just( + new RSocketTaskHandler(idToRSocketMap, tasksToProcess, id, sendingSocket)) + .doOnSuccess(__ -> checkTasksToDeliver(sendingSocket, id)); + } + + return Mono.error( + new IllegalStateException("There is already a client connected with the same ID")); + } + + private void checkTasksToDeliver(RSocket sendingSocket, String id) { + logger.info("Accepted a new client connection with ID {}. Checking for remaining tasks", id); + BlockingQueue tasksToDeliver = this.idToCompletedTasksMap.get(id); + + if (tasksToDeliver == null || tasksToDeliver.isEmpty()) { + // means nothing yet to send + return; + } + + logger.info("Found remaining tasks to deliver for client {}", id); + + for (; ; ) { + Task task = tasksToDeliver.poll(); + + if (task == null) { + return; + } + + sendingSocket + .fireAndForget(DefaultPayload.create(task.content)) + .subscribe( + null, + e -> { + // offers back a task if it has not been delivered + tasksToDeliver.offer(task); + }); + } + } + + private static class RSocketTaskHandler implements RSocket { + + private final String id; + private final RSocket sendingSocket; + private ConcurrentMap idToRSocketMap; + private Sinks.Many tasksToProcess; + + public RSocketTaskHandler( + ConcurrentMap idToRSocketMap, + Sinks.Many tasksToProcess, + String id, + RSocket sendingSocket) { + this.id = id; + this.sendingSocket = sendingSocket; + this.idToRSocketMap = idToRSocketMap; + this.tasksToProcess = tasksToProcess; + } + + @Override + public Mono fireAndForget(Payload payload) { + logger.info("Received a Task[{}] from Client.ID[{}]", payload.getDataUtf8(), id); + Sinks.EmitResult result = tasksToProcess.tryEmitNext(new Task(id, payload.getDataUtf8())); + payload.release(); + return result.isFailure() ? Mono.error(new Sinks.EmissionException(result)) : Mono.empty(); + } + + @Override + public void dispose() { + idToRSocketMap.remove(id, sendingSocket); + } + } + } + + static class Task { + final String id; + final String content; + + Task(String id, String content) { + this.id = id; + this.content = content; + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LeaseManager.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LeaseManager.java new file mode 100644 index 000000000..272caf7a0 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LeaseManager.java @@ -0,0 +1,144 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.common; + +import java.util.concurrent.BlockingDeque; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public class LeaseManager implements Runnable { + + static final Logger logger = LoggerFactory.getLogger(LeaseManager.class); + + volatile int activeConnectionsCount; + static final AtomicIntegerFieldUpdater ACTIVE_CONNECTIONS_COUNT = + AtomicIntegerFieldUpdater.newUpdater(LeaseManager.class, "activeConnectionsCount"); + + volatile int stateAndInFlight; + static final AtomicIntegerFieldUpdater STATE_AND_IN_FLIGHT = + AtomicIntegerFieldUpdater.newUpdater(LeaseManager.class, "stateAndInFlight"); + + static final int MASK_PAUSED = 0b1_000_0000_0000_0000_0000_0000_0000_0000; + static final int MASK_IN_FLIGHT = 0b0_111_1111_1111_1111_1111_1111_1111_1111; + + final BlockingDeque sendersQueue = new LinkedBlockingDeque<>(); + final Scheduler worker = Schedulers.newSingle(LeaseManager.class.getName()); + + final int capacity; + final int ttl; + + public LeaseManager(int capacity, int ttl) { + this.capacity = capacity; + this.ttl = ttl; + } + + @Override + public void run() { + try { + LimitBasedLeaseSender leaseSender = sendersQueue.poll(); + + if (leaseSender == null) { + return; + } + + if (leaseSender.isDisposed()) { + logger.debug("Connection[" + leaseSender.connectionId + "]: LeaseSender is Disposed"); + worker.schedule(this); + return; + } + + int limit = leaseSender.limitAlgorithm.getLimit(); + + if (limit == 0) { + throw new IllegalStateException("Limit is 0"); + } + + if (pauseIfNoCapacity()) { + sendersQueue.addFirst(leaseSender); + logger.debug("Pause execution. Not enough capacity"); + return; + } + + leaseSender.sendLease(ttl, limit); + sendersQueue.offer(leaseSender); + + int activeConnections = activeConnectionsCount; + int nextDelay = activeConnections == 0 ? ttl : (ttl / activeConnections); + + logger.debug("Next check happens in " + nextDelay + "ms"); + + worker.schedule(this, nextDelay, TimeUnit.MILLISECONDS); + } catch (Throwable e) { + logger.error("LeaseSender failed to send lease", e); + } + } + + int incrementInFlightAndGet() { + for (; ; ) { + int state = stateAndInFlight; + int paused = state & MASK_PAUSED; + int inFlight = stateAndInFlight & MASK_IN_FLIGHT; + + // assume overflow is impossible due to max concurrency in RSocket it self + int nextInFlight = inFlight + 1; + + if (STATE_AND_IN_FLIGHT.compareAndSet(this, state, nextInFlight | paused)) { + return nextInFlight; + } + } + } + + void decrementInFlight() { + for (; ; ) { + int state = stateAndInFlight; + int paused = state & MASK_PAUSED; + int inFlight = stateAndInFlight & MASK_IN_FLIGHT; + + // assume overflow is impossible due to max concurrency in RSocket it self + int nextInFlight = inFlight - 1; + + if (inFlight == capacity && paused == MASK_PAUSED) { + if (STATE_AND_IN_FLIGHT.compareAndSet(this, state, nextInFlight)) { + logger.debug("Resume execution"); + worker.schedule(this); + return; + } + } else { + if (STATE_AND_IN_FLIGHT.compareAndSet(this, state, nextInFlight | paused)) { + return; + } + } + } + } + + boolean pauseIfNoCapacity() { + int capacity = this.capacity; + for (; ; ) { + int inFlight = stateAndInFlight; + + if (inFlight < capacity) { + return false; + } + + if (STATE_AND_IN_FLIGHT.compareAndSet(this, inFlight, inFlight | MASK_PAUSED)) { + return true; + } + } + } + + void unregister() { + ACTIVE_CONNECTIONS_COUNT.decrementAndGet(this); + } + + void register(LimitBasedLeaseSender sender) { + sendersQueue.offer(sender); + final int activeCount = ACTIVE_CONNECTIONS_COUNT.getAndIncrement(this); + + if (activeCount == 0) { + worker.schedule(this); + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedLeaseSender.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedLeaseSender.java new file mode 100644 index 000000000..8e1b27823 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedLeaseSender.java @@ -0,0 +1,54 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.common; + +import com.netflix.concurrency.limits.Limit; +import io.rsocket.lease.Lease; +import io.rsocket.lease.TrackingLeaseSender; +import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Sinks; +import reactor.util.concurrent.Queues; + +public class LimitBasedLeaseSender extends LimitBasedStatsCollector implements TrackingLeaseSender { + + static final Logger logger = LoggerFactory.getLogger(LimitBasedLeaseSender.class); + + final String connectionId; + final Sinks.Many sink = + Sinks.many().unicast().onBackpressureBuffer(Queues.one().get()); + + public LimitBasedLeaseSender( + String connectionId, LeaseManager leaseManager, Limit limitAlgorithm) { + super(leaseManager, limitAlgorithm); + this.connectionId = connectionId; + } + + @Override + public Flux send() { + logger.info("Received new leased Connection[" + connectionId + "]"); + + leaseManager.register(this); + + return sink.asFlux(); + } + + public void sendLease(int ttl, int amount) { + final Lease nextLease = Lease.create(Duration.ofMillis(ttl), amount); + final Sinks.EmitResult result = sink.tryEmitNext(nextLease); + + if (result.isFailure()) { + logger.warn( + "Connection[" + + connectionId + + "]. Issued Lease: [" + + nextLease + + "] was not sent due to " + + result); + } else { + if (logger.isDebugEnabled()) { + logger.debug("To Connection[" + connectionId + "]: Issued Lease: [" + nextLease + "]"); + } + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedStatsCollector.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedStatsCollector.java new file mode 100644 index 000000000..7f639ab87 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedStatsCollector.java @@ -0,0 +1,73 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.common; + +import com.netflix.concurrency.limits.Limit; +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.LongSupplier; +import reactor.util.annotation.Nullable; + +public class LimitBasedStatsCollector extends AtomicBoolean implements RequestInterceptor { + + final LeaseManager leaseManager; + final Limit limitAlgorithm; + + final ConcurrentMap inFlightMap = new ConcurrentHashMap<>(); + final ConcurrentMap timeMap = new ConcurrentHashMap<>(); + + final LongSupplier clock = System::nanoTime; + + public LimitBasedStatsCollector(LeaseManager leaseManager, Limit limitAlgorithm) { + this.leaseManager = leaseManager; + this.limitAlgorithm = limitAlgorithm; + } + + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + long startTime = clock.getAsLong(); + + int currentInFlight = leaseManager.incrementInFlightAndGet(); + + inFlightMap.put(streamId, currentInFlight); + timeMap.put(streamId, startTime); + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) {} + + @Override + public void onTerminate(int streamId, FrameType requestType, @Nullable Throwable t) { + leaseManager.decrementInFlight(); + + Long startTime = timeMap.remove(streamId); + Integer currentInflight = inFlightMap.remove(streamId); + + limitAlgorithm.onSample(startTime, clock.getAsLong() - startTime, currentInflight, t != null); + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + leaseManager.decrementInFlight(); + + Long startTime = timeMap.remove(streamId); + Integer currentInflight = inFlightMap.remove(streamId); + + limitAlgorithm.onSample(startTime, clock.getAsLong() - startTime, currentInflight, true); + } + + @Override + public boolean isDisposed() { + return get(); + } + + @Override + public void dispose() { + if (!getAndSet(true)) { + leaseManager.unregister(); + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/Task.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/Task.java new file mode 100644 index 000000000..a18dd9484 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/Task.java @@ -0,0 +1,27 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.controller; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +// emulating a worker that process data from the queue +public class Task implements Runnable { + private static final Logger logger = LoggerFactory.getLogger(Task.class); + + final String message; + final int processingTime; + + Task(String message, int processingTime) { + this.message = message; + this.processingTime = processingTime; + } + + @Override + public void run() { + logger.info("Processing Task[{}]", message); + try { + Thread.sleep(processingTime); // emulating processing + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/TasksHandlingRSocket.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/TasksHandlingRSocket.java new file mode 100644 index 000000000..cbecadfc3 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/TasksHandlingRSocket.java @@ -0,0 +1,44 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.controller; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Scheduler; + +public class TasksHandlingRSocket implements RSocket { + + private static final Logger logger = LoggerFactory.getLogger(TasksHandlingRSocket.class); + + final Disposable terminatable; + final Scheduler workScheduler; + final int processingTime; + + public TasksHandlingRSocket(Disposable terminatable, Scheduler scheduler, int processingTime) { + this.terminatable = terminatable; + this.workScheduler = scheduler; + this.processingTime = processingTime; + } + + @Override + public Mono fireAndForget(Payload payload) { + + // specifically to show that lease can limit rate of fnf requests in + // that example + String message = payload.getDataUtf8(); + payload.release(); + + return Mono.fromRunnable(new Task(message, processingTime)) + // schedule task on specific, limited in size scheduler + .subscribeOn(workScheduler) + // if errors - terminates server + .doOnError( + t -> { + logger.error("Queue has been overflowed. Terminating server"); + terminatable.dispose(); + System.exit(9); + }); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/README.MD b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/README.MD new file mode 100644 index 000000000..e69de29bb diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RequestingServer.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RequestingServer.java new file mode 100644 index 000000000..30eb0c0e3 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RequestingServer.java @@ -0,0 +1,78 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.lease.advanced.invertmulticlient; + +import io.rsocket.RSocket; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.ByteBufPayload; +import java.util.Comparator; +import java.util.concurrent.PriorityBlockingQueue; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class RequestingServer { + + private static final Logger logger = LoggerFactory.getLogger(RequestingServer.class); + + public static void main(String[] args) { + PriorityBlockingQueue rSockets = + new PriorityBlockingQueue<>( + 16, Comparator.comparingDouble(RSocket::availability).reversed()); + + CloseableChannel server = + RSocketServer.create( + (setup, sendingSocket) -> { + logger.info("Received new connection"); + return Mono.just(new RSocket() {}) + .doAfterTerminate(() -> rSockets.put(sendingSocket)); + }) + .lease(spec -> spec.maxPendingRequests(Integer.MAX_VALUE)) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + logger.info("Server started on port {}", server.address().getPort()); + + // generate stream of fnfs + Flux.generate( + () -> 0L, + (state, sink) -> { + sink.next(state); + return state + 1; + }) + .flatMap( + tick -> { + logger.info("Requesting FireAndForget({})", tick); + + return Mono.fromCallable( + () -> { + RSocket rSocket = rSockets.take(); + rSockets.offer(rSocket); + return rSocket; + }) + .flatMap( + clientRSocket -> + clientRSocket.fireAndForget(ByteBufPayload.create("" + tick))) + .retry(); + }) + .blockLast(); + + server.onClose().block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RespondingClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RespondingClient.java new file mode 100644 index 000000000..4a06855b2 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RespondingClient.java @@ -0,0 +1,67 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.invertmulticlient; + +import com.netflix.concurrency.limits.limit.VegasLimit; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.examples.transport.tcp.lease.advanced.common.LeaseManager; +import io.rsocket.examples.transport.tcp.lease.advanced.common.LimitBasedLeaseSender; +import io.rsocket.examples.transport.tcp.lease.advanced.controller.TasksHandlingRSocket; +import io.rsocket.transport.netty.client.TcpClientTransport; +import java.util.Objects; +import java.util.UUID; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public class RespondingClient { + private static final Logger logger = LoggerFactory.getLogger(RespondingClient.class); + + public static final int PROCESSING_TASK_TIME = 500; + public static final int CONCURRENT_WORKERS_COUNT = 1; + public static final int QUEUE_CAPACITY = 50; + + public static void main(String[] args) { + // Queue for incoming messages represented as Flux + // Imagine that every fireAndForget that is pushed is processed by a worker + BlockingQueue tasksQueue = new ArrayBlockingQueue<>(QUEUE_CAPACITY); + + ThreadPoolExecutor threadPoolExecutor = + new ThreadPoolExecutor(1, CONCURRENT_WORKERS_COUNT, 1, TimeUnit.MINUTES, tasksQueue); + + Scheduler workScheduler = Schedulers.fromExecutorService(threadPoolExecutor); + + LeaseManager periodicLeaseSender = + new LeaseManager(CONCURRENT_WORKERS_COUNT, PROCESSING_TASK_TIME); + + Disposable.Composite disposable = Disposables.composite(); + RSocket clientRSocket = + RSocketConnector.create() + .acceptor( + SocketAcceptor.with( + new TasksHandlingRSocket(disposable, workScheduler, PROCESSING_TASK_TIME))) + .lease( + (config) -> + config.sender( + new LimitBasedLeaseSender( + UUID.randomUUID().toString(), + periodicLeaseSender, + VegasLimit.newBuilder() + .initialLimit(CONCURRENT_WORKERS_COUNT) + .maxConcurrency(QUEUE_CAPACITY) + .build()))) + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + Objects.requireNonNull(clientRSocket); + disposable.add(clientRSocket); + clientRSocket.onClose().block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/README.MD b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/README.MD new file mode 100644 index 000000000..e69de29bb diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RequestingClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RequestingClient.java new file mode 100644 index 000000000..c2fde38e3 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RequestingClient.java @@ -0,0 +1,41 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.multiclient; + +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.util.ByteBufPayload; +import java.util.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + +public class RequestingClient { + private static final Logger logger = LoggerFactory.getLogger(RequestingClient.class); + + public static void main(String[] args) { + + RSocket clientRSocket = + RSocketConnector.create() + .lease() + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + Objects.requireNonNull(clientRSocket); + + // generate stream of fnfs + Flux.generate( + () -> 0L, + (state, sink) -> { + sink.next(state); + return state + 1; + }) + .concatMap( + tick -> { + logger.info("Requesting FireAndForget({})", tick); + return clientRSocket.fireAndForget(ByteBufPayload.create("" + tick)); + }) + .blockLast(); + + clientRSocket.onClose().block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RespondingServer.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RespondingServer.java new file mode 100644 index 000000000..b54330450 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RespondingServer.java @@ -0,0 +1,81 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.lease.advanced.multiclient; + +import com.netflix.concurrency.limits.limit.VegasLimit; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketServer; +import io.rsocket.examples.transport.tcp.lease.advanced.common.LeaseManager; +import io.rsocket.examples.transport.tcp.lease.advanced.common.LimitBasedLeaseSender; +import io.rsocket.examples.transport.tcp.lease.advanced.controller.TasksHandlingRSocket; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.util.UUID; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public class RespondingServer { + + private static final Logger logger = LoggerFactory.getLogger(RespondingServer.class); + + public static final int TASK_PROCESSING_TIME = 500; + public static final int CONCURRENT_WORKERS_COUNT = 1; + public static final int QUEUE_CAPACITY = 50; + + public static void main(String[] args) { + // Queue for incoming messages represented as Flux + // Imagine that every fireAndForget that is pushed is processed by a worker + BlockingQueue tasksQueue = new ArrayBlockingQueue<>(QUEUE_CAPACITY); + + ThreadPoolExecutor threadPoolExecutor = + new ThreadPoolExecutor(1, CONCURRENT_WORKERS_COUNT, 1, TimeUnit.MINUTES, tasksQueue); + + Scheduler workScheduler = Schedulers.fromExecutorService(threadPoolExecutor); + + LeaseManager leaseManager = new LeaseManager(CONCURRENT_WORKERS_COUNT, TASK_PROCESSING_TIME); + + Disposable.Composite disposable = Disposables.composite(); + CloseableChannel server = + RSocketServer.create( + SocketAcceptor.with( + new TasksHandlingRSocket(disposable, workScheduler, TASK_PROCESSING_TIME))) + .lease( + (config) -> + config.sender( + new LimitBasedLeaseSender( + UUID.randomUUID().toString(), + leaseManager, + VegasLimit.newBuilder() + .initialLimit(CONCURRENT_WORKERS_COUNT) + .maxConcurrency(QUEUE_CAPACITY) + .build()))) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + disposable.add(server); + + logger.info("Server started on port {}", server.address().getPort()); + server.onClose().block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/simple/LeaseExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/simple/LeaseExample.java new file mode 100644 index 000000000..c54335ccc --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/simple/LeaseExample.java @@ -0,0 +1,160 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.lease.simple; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.lease.Lease; +import io.rsocket.lease.LeaseSender; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import java.util.Objects; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class LeaseExample { + + private static final Logger logger = LoggerFactory.getLogger(LeaseExample.class); + + private static final String SERVER_TAG = "server"; + private static final String CLIENT_TAG = "client"; + + public static void main(String[] args) { + // Queue for incoming messages represented as Flux + // Imagine that every fireAndForget that is pushed is processed by a worker + + int queueCapacity = 50; + BlockingQueue messagesQueue = new ArrayBlockingQueue<>(queueCapacity); + + // emulating a worker that process data from the queue + Thread workerThread = + new Thread( + () -> { + try { + while (!Thread.currentThread().isInterrupted()) { + String message = messagesQueue.take(); + logger.info("Process message {}", message); + Thread.sleep(500); // emulating processing + } + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + + workerThread.start(); + + CloseableChannel server = + RSocketServer.create( + (setup, sendingSocket) -> + Mono.just( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + // add element. if overflows errors and terminates execution + // specifically to show that lease can limit rate of fnf requests in + // that example + try { + if (!messagesQueue.offer(payload.getDataUtf8())) { + logger.error("Queue has been overflowed. Terminating execution"); + sendingSocket.dispose(); + workerThread.interrupt(); + } + } finally { + payload.release(); + } + return Mono.empty(); + } + })) + .lease(leases -> leases.sender(new LeaseCalculator(SERVER_TAG, messagesQueue))) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + RSocket clientRSocket = + RSocketConnector.create() + .lease((config) -> config.maxPendingRequests(1)) + .connect(TcpClientTransport.create(server.address())) + .block(); + + Objects.requireNonNull(clientRSocket); + + // generate stream of fnfs + Flux.generate( + () -> 0L, + (state, sink) -> { + sink.next(state); + return state + 1; + }) + // here we wait for the first lease for the responder side and start execution + // on if there is allowance + .concatMap( + tick -> { + logger.info("Requesting FireAndForget({})", tick); + return clientRSocket.fireAndForget(ByteBufPayload.create("" + tick)); + }) + .blockLast(); + + clientRSocket.onClose().block(); + server.dispose(); + } + + /** + * This is a class responsible for making decision on whether Responder is ready to receive new + * FireAndForget or not base in the number of messages enqueued.
+ * In the nutshell this is responder-side rate-limiter logic which is created for every new + * connection.
+ * In real-world projects this class has to issue leases based on real metrics + */ + private static class LeaseCalculator implements LeaseSender { + final String tag; + final BlockingQueue queue; + + public LeaseCalculator(String tag, BlockingQueue queue) { + this.tag = tag; + this.queue = queue; + } + + @Override + public Flux send() { + Duration ttlDuration = Duration.ofSeconds(5); + // The interval function is used only for the demo purpose and should not be + // considered as the way to issue leases. + // For advanced RateLimiting with Leasing + // consider adopting https://github.com/Netflix/concurrency-limits#server-limiter + return Flux.interval(Duration.ZERO, ttlDuration.dividedBy(2)) + .handle( + (__, sink) -> { + // put queue.remainingCapacity() + 1 here if you want to observe that app is + // terminated because of the queue overflowing + int requests = queue.remainingCapacity(); + + // reissue new lease only if queue has remaining capacity to + // accept more requests + if (requests > 0) { + sink.next(Lease.create(ttlDuration, requests)); + } + }); + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/loadbalancer/RoundRobinRSocketLoadbalancerExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/loadbalancer/RoundRobinRSocketLoadbalancerExample.java new file mode 100644 index 000000000..abed4a52d --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/loadbalancer/RoundRobinRSocketLoadbalancerExample.java @@ -0,0 +1,110 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.examples.transport.tcp.loadbalancer; + +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketClient; +import io.rsocket.core.RSocketServer; +import io.rsocket.loadbalance.LoadbalanceRSocketClient; +import io.rsocket.loadbalance.LoadbalanceTarget; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class RoundRobinRSocketLoadbalancerExample { + + public static void main(String[] args) { + CloseableChannel server1 = + RSocketServer.create( + SocketAcceptor.forRequestResponse( + p -> { + System.out.println("Server 1 got fnf " + p.getDataUtf8()); + return Mono.just(DefaultPayload.create("Server 1 response")) + .delayElement(Duration.ofMillis(100)); + })) + .bindNow(TcpServerTransport.create(8080)); + + CloseableChannel server2 = + RSocketServer.create( + SocketAcceptor.forRequestResponse( + p -> { + System.out.println("Server 2 got fnf " + p.getDataUtf8()); + return Mono.just(DefaultPayload.create("Server 2 response")) + .delayElement(Duration.ofMillis(100)); + })) + .bindNow(TcpServerTransport.create(8081)); + + CloseableChannel server3 = + RSocketServer.create( + SocketAcceptor.forRequestResponse( + p -> { + System.out.println("Server 3 got fnf " + p.getDataUtf8()); + return Mono.just(DefaultPayload.create("Server 3 response")) + .delayElement(Duration.ofMillis(100)); + })) + .bindNow(TcpServerTransport.create(8082)); + + LoadbalanceTarget target8080 = LoadbalanceTarget.from("8080", TcpClientTransport.create(8080)); + LoadbalanceTarget target8081 = LoadbalanceTarget.from("8081", TcpClientTransport.create(8081)); + LoadbalanceTarget target8082 = LoadbalanceTarget.from("8082", TcpClientTransport.create(8082)); + + Flux> producer = + Flux.interval(Duration.ofSeconds(5)) + .log() + .map( + i -> { + int val = i.intValue(); + switch (val) { + case 0: + return Collections.emptyList(); + case 1: + return Collections.singletonList(target8080); + case 2: + return Arrays.asList(target8080, target8081); + case 3: + return Arrays.asList(target8080, target8082); + case 4: + return Arrays.asList(target8081, target8082); + case 5: + return Arrays.asList(target8080, target8081, target8082); + case 6: + return Collections.emptyList(); + case 7: + return Collections.emptyList(); + default: + return Arrays.asList(target8080, target8081, target8082); + } + }); + + RSocketClient rSocketClient = + LoadbalanceRSocketClient.builder(producer).roundRobinLoadbalanceStrategy().build(); + + for (int i = 0; i < 10000; i++) { + try { + rSocketClient.requestResponse(Mono.just(DefaultPayload.create("test" + i))).block(); + } catch (Throwable t) { + // no ops + } + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/CompositeMetadataExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/CompositeMetadataExample.java new file mode 100644 index 000000000..a0a02a946 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/CompositeMetadataExample.java @@ -0,0 +1,102 @@ +/* + * Copyright 2015-Present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.metadata.routing; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.metadata.CompositeMetadata; +import io.rsocket.metadata.CompositeMetadataCodec; +import io.rsocket.metadata.RoutingMetadata; +import io.rsocket.metadata.TaggingMetadataCodec; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.ByteBufPayload; +import java.util.Collections; +import java.util.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +public class CompositeMetadataExample { + static final Logger logger = LoggerFactory.getLogger(CompositeMetadataExample.class); + + public static void main(String[] args) { + RSocketServer.create( + SocketAcceptor.forRequestResponse( + payload -> { + final String route = decodeRoute(payload.sliceMetadata()); + + logger.info("Received RequestResponse[route={}]", route); + + payload.release(); + + if ("my.test.route".equals(route)) { + return Mono.just(ByteBufPayload.create("Hello From My Test Route")); + } + + return Mono.error(new IllegalArgumentException("Route " + route + " not found")); + })) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + RSocket socket = + RSocketConnector.create() + // here we specify that every metadata payload will be encoded using + // CompositeMetadata layout as specified in the following subspec + // https://github.com/rsocket/rsocket/blob/master/Extensions/CompositeMetadata.md + .metadataMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()) + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + final ByteBuf routeMetadata = + TaggingMetadataCodec.createTaggingContent( + ByteBufAllocator.DEFAULT, Collections.singletonList("my.test.route")); + final CompositeByteBuf compositeMetadata = ByteBufAllocator.DEFAULT.compositeBuffer(); + + CompositeMetadataCodec.encodeAndAddMetadata( + compositeMetadata, + ByteBufAllocator.DEFAULT, + WellKnownMimeType.MESSAGE_RSOCKET_ROUTING, + routeMetadata); + + socket + .requestResponse( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "HelloWorld"), compositeMetadata)) + .log() + .block(); + } + + static String decodeRoute(ByteBuf metadata) { + final CompositeMetadata compositeMetadata = new CompositeMetadata(metadata, false); + + for (CompositeMetadata.Entry metadatum : compositeMetadata) { + if (Objects.requireNonNull(metadatum.getMimeType()) + .equals(WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString())) { + return new RoutingMetadata(metadatum.getContent()).iterator().next(); + } + } + + return null; + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/RoutingMetadataExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/RoutingMetadataExample.java new file mode 100644 index 000000000..2aee18bf9 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/RoutingMetadataExample.java @@ -0,0 +1,83 @@ +/* + * Copyright 2015-Present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.metadata.routing; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.metadata.RoutingMetadata; +import io.rsocket.metadata.TaggingMetadataCodec; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.ByteBufPayload; +import java.util.Collections; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +public class RoutingMetadataExample { + static final Logger logger = LoggerFactory.getLogger(RoutingMetadataExample.class); + + public static void main(String[] args) { + RSocketServer.create( + SocketAcceptor.forRequestResponse( + payload -> { + final String route = decodeRoute(payload.sliceMetadata()); + + logger.info("Received RequestResponse[route={}]", route); + + payload.release(); + + if ("my.test.route".equals(route)) { + return Mono.just(ByteBufPayload.create("Hello From My Test Route")); + } + + return Mono.error(new IllegalArgumentException("Route " + route + " not found")); + })) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + RSocket socket = + RSocketConnector.create() + // here we specify that route will be encoded using + // Routing&Tagging Metadata layout specified at this + // subspec https://github.com/rsocket/rsocket/blob/master/Extensions/Routing.md + .metadataMimeType(WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()) + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + final ByteBuf routeMetadata = + TaggingMetadataCodec.createTaggingContent( + ByteBufAllocator.DEFAULT, Collections.singletonList("my.test.route")); + socket + .requestResponse( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "HelloWorld"), routeMetadata)) + .log() + .block(); + } + + static String decodeRoute(ByteBuf metadata) { + final RoutingMetadata routingMetadata = new RoutingMetadata(metadata); + + return routingMetadata.iterator().next(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/plugins/LimitRateInterceptorExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/plugins/LimitRateInterceptorExample.java new file mode 100644 index 000000000..5491a1aab --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/plugins/LimitRateInterceptorExample.java @@ -0,0 +1,83 @@ +package io.rsocket.examples.transport.tcp.plugins; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.plugins.LimitRateInterceptor; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + +public class LimitRateInterceptorExample { + + private static final Logger logger = LoggerFactory.getLogger(LimitRateInterceptorExample.class); + + public static void main(String[] args) { + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + return Flux.interval(Duration.ofMillis(100)) + .doOnRequest( + e -> logger.debug("Server publisher receives request for " + e)) + .map(aLong -> DefaultPayload.create("Interval: " + aLong)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads) + .doOnRequest( + e -> logger.debug("Server publisher receives request for " + e)); + } + })) + .interceptors(registry -> registry.forResponder(LimitRateInterceptor.forResponder(64))) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + RSocket socket = + RSocketConnector.create() + .interceptors(registry -> registry.forRequester(LimitRateInterceptor.forRequester(64))) + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + logger.debug( + "\n\nStart of requestStream interaction\n" + "----------------------------------\n"); + + socket + .requestStream(DefaultPayload.create("Hello")) + .doOnRequest(e -> logger.debug("Client sends requestN(" + e + ")")) + .map(Payload::getDataUtf8) + .doOnNext(logger::debug) + .take(10) + .then() + .block(); + + logger.debug( + "\n\nStart of requestChannel interaction\n" + "-----------------------------------\n"); + + socket + .requestChannel( + Flux.generate( + () -> 1L, + (s, sink) -> { + sink.next(DefaultPayload.create("Next " + s)); + return ++s; + }) + .doOnRequest(e -> logger.debug("Client publisher receives request for " + e))) + .doOnRequest(e -> logger.debug("Client sends requestN(" + e + ")")) + .map(Payload::getDataUtf8) + .doOnNext(logger::debug) + .take(10) + .then() + .doFinally(signalType -> socket.dispose()) + .then() + .block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java index 258a997bd..0c372d2d8 100644 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,66 +16,54 @@ package io.rsocket.examples.transport.tcp.requestresponse; -import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.netty.client.TcpClientTransport; import io.rsocket.transport.netty.server.TcpServerTransport; -import io.rsocket.util.PayloadImpl; +import io.rsocket.util.DefaultPayload; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Mono; public final class HelloWorldClient { + private static final Logger logger = LoggerFactory.getLogger(HelloWorldClient.class); + public static void main(String[] args) { - RSocketFactory.receive() - .acceptor( - (setupPayload, reactiveSocket) -> - Mono.just( - new AbstractRSocket() { - boolean fail = true; - @Override - public Mono requestResponse(Payload p) { - if (fail) { - fail = false; - return Mono.error(new Throwable()); - } else { - return Mono.just(p); - } - } - })) - .transport(TcpServerTransport.create("localhost", 7000)) - .start() - .subscribe(); + RSocket rsocket = + new RSocket() { + boolean fail = true; - RSocket socket = - RSocketFactory.connect() - .transport(TcpClientTransport.create("localhost", 7000)) - .start() - .block(); + @Override + public Mono requestResponse(Payload p) { + if (fail) { + fail = false; + return Mono.error(new Throwable("Simulated error")); + } else { + return Mono.just(p); + } + } + }; - socket - .requestResponse(new PayloadImpl("Hello")) - .map(Payload::getDataUtf8) - .onErrorReturn("error") - .doOnNext(System.out::println) - .block(); + RSocketServer.create(SocketAcceptor.with(rsocket)) + .bindNow(TcpServerTransport.create("localhost", 7000)); - socket - .requestResponse(new PayloadImpl("Hello")) - .map(Payload::getDataUtf8) - .onErrorReturn("error") - .doOnNext(System.out::println) - .block(); + RSocket socket = + RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000)).block(); - socket - .requestResponse(new PayloadImpl("Hello")) - .map(Payload::getDataUtf8) - .onErrorReturn("error") - .doOnNext(System.out::println) - .block(); + for (int i = 0; i < 3; i++) { + socket + .requestResponse(DefaultPayload.create("Hello")) + .map(Payload::getDataUtf8) + .onErrorReturn("error") + .doOnNext(logger::debug) + .block(); + } - socket.close().block(); + socket.dispose(); } } diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/Files.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/Files.java new file mode 100644 index 000000000..6724ca93f --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/Files.java @@ -0,0 +1,141 @@ +package io.rsocket.examples.transport.tcp.resume; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; +import java.io.BufferedInputStream; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.SynchronousSink; + +class Files { + private static final Logger logger = LoggerFactory.getLogger(Files.class); + + public static Flux fileSource(String fileName, int chunkSizeBytes) { + return Flux.generate( + () -> new FileState(fileName, chunkSizeBytes), FileState::consumeNext, FileState::dispose); + } + + public static Subscriber fileSink(String fileName, int windowSize) { + return new Subscriber() { + Subscription s; + int requests = windowSize; + OutputStream outputStream; + int receivedBytes; + int receivedCount; + + @Override + public void onSubscribe(Subscription s) { + this.s = s; + this.s.request(requests); + } + + @Override + public void onNext(Payload payload) { + ByteBuf data = payload.data(); + receivedBytes += data.readableBytes(); + receivedCount += 1; + logger.debug("Received file chunk: " + receivedCount + ". Total size: " + receivedBytes); + if (outputStream == null) { + outputStream = open(fileName); + } + write(outputStream, data); + payload.release(); + + requests--; + if (requests == windowSize / 2) { + requests += windowSize; + s.request(windowSize); + } + } + + private void write(OutputStream outputStream, ByteBuf byteBuf) { + try { + byteBuf.readBytes(outputStream, byteBuf.readableBytes()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void onError(Throwable t) { + close(outputStream); + } + + @Override + public void onComplete() { + close(outputStream); + } + + private OutputStream open(String filename) { + try { + /*do not buffer for demo purposes*/ + return new FileOutputStream(filename); + } catch (FileNotFoundException e) { + throw new RuntimeException(e); + } + } + + private void close(OutputStream stream) { + if (stream != null) { + try { + stream.close(); + } catch (IOException e) { + } + } + } + }; + } + + private static class FileState { + private final String fileName; + private final int chunkSizeBytes; + private BufferedInputStream inputStream; + private byte[] chunkBytes; + + public FileState(String fileName, int chunkSizeBytes) { + this.fileName = fileName; + this.chunkSizeBytes = chunkSizeBytes; + } + + public FileState consumeNext(SynchronousSink sink) { + if (inputStream == null) { + InputStream in = getClass().getClassLoader().getResourceAsStream(fileName); + if (in == null) { + sink.error(new FileNotFoundException(fileName)); + return this; + } + this.inputStream = new BufferedInputStream(in); + this.chunkBytes = new byte[chunkSizeBytes]; + } + try { + int consumedBytes = inputStream.read(chunkBytes); + if (consumedBytes == -1) { + sink.complete(); + } else { + sink.next(Unpooled.copiedBuffer(chunkBytes, 0, consumedBytes)); + } + } catch (IOException e) { + sink.error(e); + } + return this; + } + + public void dispose() { + if (inputStream != null) { + try { + inputStream.close(); + } catch (IOException e) { + } + } + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java new file mode 100644 index 000000000..ba82c7c93 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java @@ -0,0 +1,119 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.resume; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.core.Resume; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.util.retry.Retry; + +public class ResumeFileTransfer { + + /*amount of file chunks requested by subscriber: n, refilled on n/2 of received items*/ + private static final int PREFETCH_WINDOW_SIZE = 4; + private static final Logger logger = LoggerFactory.getLogger(ResumeFileTransfer.class); + + public static void main(String[] args) { + + Resume resume = + new Resume() + .sessionDuration(Duration.ofMinutes(5)) + .retry( + Retry.fixedDelay(Long.MAX_VALUE, Duration.ofSeconds(1)) + .doBeforeRetry(s -> logger.debug("Disconnected. Trying to resume..."))); + + RequestCodec codec = new RequestCodec(); + + CloseableChannel server = + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> { + Request request = codec.decode(payload); + payload.release(); + String fileName = request.getFileName(); + int chunkSize = request.getChunkSize(); + + Flux ticks = Flux.interval(Duration.ofMillis(500)).onBackpressureDrop(); + + return Files.fileSource(fileName, chunkSize) + .map(DefaultPayload::create) + .zipWith(ticks, (p, tick) -> p) + .log("server"); + })) + .resume(resume) + .bindNow(TcpServerTransport.create("localhost", 8000)); + + RSocket client = + RSocketConnector.create() + .resume(resume) + .connect(TcpClientTransport.create("localhost", 8001)) + .block(); + + client + .requestStream(codec.encode(new Request(16, "lorem.txt"))) + .log("client") + .doFinally(s -> server.dispose()) + .subscribe(Files.fileSink("rsocket-examples/build/lorem_output.txt", PREFETCH_WINDOW_SIZE)); + + server.onClose().block(); + } + + private static class RequestCodec { + + public Payload encode(Request request) { + String encoded = request.getChunkSize() + ":" + request.getFileName(); + return DefaultPayload.create(encoded); + } + + public Request decode(Payload payload) { + String encoded = payload.getDataUtf8(); + String[] chunkSizeAndFileName = encoded.split(":"); + int chunkSize = Integer.parseInt(chunkSizeAndFileName[0]); + String fileName = chunkSizeAndFileName[1]; + return new Request(chunkSize, fileName); + } + } + + private static class Request { + private final int chunkSize; + private final String fileName; + + public Request(int chunkSize, String fileName) { + this.chunkSize = chunkSize; + this.fileName = fileName; + } + + public int getChunkSize() { + return chunkSize; + } + + public String getFileName() { + return fileName; + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/readme.md b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/readme.md new file mode 100644 index 000000000..55e761fe8 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/readme.md @@ -0,0 +1,29 @@ +1. Start socat. It is used for emulation of transport disconnects + +`socat -d TCP-LISTEN:8001,fork,reuseaddr TCP:localhost:8000` + +2. start `ResumeFileTransfer.main` + +3. terminate/start socat periodically for session resumption + +`ResumeFileTransfer` output is as follows + +``` +Received file chunk: 7. Total size: 112 +Received file chunk: 8. Total size: 128 +Received file chunk: 9. Total size: 144 +Received file chunk: 10. Total size: 160 +Disconnected. Trying to resume connection... +Disconnected. Trying to resume connection... +Disconnected. Trying to resume connection... +Disconnected. Trying to resume connection... +Disconnected. Trying to resume connection... +Received file chunk: 11. Total size: 176 +Received file chunk: 12. Total size: 192 +Received file chunk: 13. Total size: 208 +Received file chunk: 14. Total size: 224 +Received file chunk: 15. Total size: 240 +Received file chunk: 16. Total size: 256 +``` + +It transfers file from `resources/lorem.txt` to `build/out/lorem_output.txt` in chunks of 16 bytes every 500 millis diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ClientStreamingToServer.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ClientStreamingToServer.java new file mode 100644 index 000000000..af0df3be1 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ClientStreamingToServer.java @@ -0,0 +1,63 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.stream; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + +public final class ClientStreamingToServer { + + private static final Logger logger = LoggerFactory.getLogger(ClientStreamingToServer.class); + + public static void main(String[] args) throws InterruptedException { + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> + Flux.interval(Duration.ofMillis(100)) + .map(aLong -> DefaultPayload.create("Interval: " + aLong)))) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + RSocket socket = + RSocketConnector.create() + .setupPayload(DefaultPayload.create("test", "test")) + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + final Payload payload = DefaultPayload.create("Hello"); + socket + .requestStream(payload) + .map(Payload::getDataUtf8) + .doOnNext(logger::debug) + .take(10) + .then() + .doFinally(signalType -> socket.dispose()) + .then() + .block(); + + Thread.sleep(1000000); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ServerStreamingToClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ServerStreamingToClient.java new file mode 100644 index 000000000..10ed34553 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ServerStreamingToClient.java @@ -0,0 +1,60 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.stream; + +import static io.rsocket.SocketAcceptor.forRequestStream; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public final class ServerStreamingToClient { + + public static void main(String[] args) { + + RSocketServer.create( + (setup, rsocket) -> { + rsocket + .requestStream(DefaultPayload.create("Hello-Bidi")) + .map(Payload::getDataUtf8) + .log() + .subscribe(); + + return Mono.just(new RSocket() {}); + }) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + RSocket rsocket = + RSocketConnector.create() + .acceptor( + forRequestStream( + payload -> + Flux.interval(Duration.ofSeconds(1)) + .map(aLong -> DefaultPayload.create("Bi-di Response => " + aLong)))) + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + rsocket.onClose().block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/StreamingClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/StreamingClient.java deleted file mode 100644 index 0df129a67..000000000 --- a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/StreamingClient.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - *

- * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - *

- * http://www.apache.org/licenses/LICENSE-2.0 - *

- * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on - * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package io.rsocket.examples.transport.tcp.stream; - -import io.rsocket.*; -import io.rsocket.transport.netty.client.TcpClientTransport; -import io.rsocket.transport.netty.server.TcpServerTransport; -import io.rsocket.util.PayloadImpl; -import java.time.Duration; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -public final class StreamingClient { - - public static void main(String[] args) { - RSocketFactory.receive() - .acceptor(new SocketAcceptorImpl()) - .transport(TcpServerTransport.create("localhost", 7000)) - .start() - .subscribe(); - - RSocket socket = - RSocketFactory.connect() - .transport(TcpClientTransport.create("localhost", 7000)) - .start() - .block(); - - socket - .requestStream(new PayloadImpl("Hello")) - .map(Payload::getDataUtf8) - .doOnNext(System.out::println) - .take(10) - .thenEmpty(socket.close()) - .block(); - } - - private static class SocketAcceptorImpl implements SocketAcceptor { - @Override - public Mono accept(ConnectionSetupPayload setupPayload, RSocket reactiveSocket) { - return Mono.just( - new AbstractRSocket() { - @Override - public Flux requestStream(Payload payload) { - return Flux.interval(Duration.ofMillis(100)) - .map(aLong -> new PayloadImpl("Interval: " + aLong)); - } - }); - } - } -} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketAggregationSample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketAggregationSample.java new file mode 100644 index 000000000..89304853c --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketAggregationSample.java @@ -0,0 +1,80 @@ +/* + * Copyright 2015-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.ws; + +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.WebsocketDuplexConnection; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.netty.Connection; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +public class WebSocketAggregationSample { + + private static final Logger logger = LoggerFactory.getLogger(WebSocketAggregationSample.class); + + public static void main(String[] args) { + + ServerTransport.ConnectionAcceptor connectionAcceptor = + RSocketServer.create(SocketAcceptor.forRequestResponse(Mono::just)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .asConnectionAcceptor(); + + DisposableServer server = + HttpServer.create() + .host("localhost") + .port(0) + .handle( + (req, res) -> + res.sendWebsocket( + (in, out) -> + connectionAcceptor + .apply( + new WebsocketDuplexConnection( + (Connection) in.aggregateFrames())) + .then(out.neverComplete()))) + .bindNow(); + + WebsocketClientTransport transport = + WebsocketClientTransport.create(server.host(), server.port()); + + RSocket clientRSocket = + RSocketConnector.create() + .keepAlive(Duration.ofMinutes(10), Duration.ofMinutes(10)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(transport) + .block(); + + Flux.range(1, 100) + .concatMap(i -> clientRSocket.requestResponse(ByteBufPayload.create("Hello " + i))) + .doOnNext(payload -> logger.debug("Processed " + payload.getDataUtf8())) + .blockLast(); + clientRSocket.dispose(); + server.dispose(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java new file mode 100644 index 000000000..72e003d2a --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java @@ -0,0 +1,99 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.ws; + +import io.netty.handler.codec.http.HttpResponseStatus; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.WebsocketDuplexConnection; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.netty.Connection; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +public class WebSocketHeadersSample { + + private static final Logger logger = LoggerFactory.getLogger(WebSocketHeadersSample.class); + + public static void main(String[] args) { + + ServerTransport.ConnectionAcceptor connectionAcceptor = + RSocketServer.create(SocketAcceptor.forRequestResponse(Mono::just)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .asConnectionAcceptor(); + + DisposableServer server = + HttpServer.create() + .host("localhost") + .port(0) + .route( + routes -> + routes.get( + "/", + (req, res) -> { + if (req.requestHeaders().containsValue("Authorization", "test", true)) { + return res.sendWebsocket( + (in, out) -> + connectionAcceptor + .apply(new WebsocketDuplexConnection((Connection) in)) + .then(out.neverComplete())); + } + res.status(HttpResponseStatus.UNAUTHORIZED); + return res.send(); + })) + .bindNow(); + + logger.debug( + "\n\nStart of Authorized WebSocket Connection\n----------------------------------\n"); + + WebsocketClientTransport transport = + WebsocketClientTransport.create(server.host(), server.port()) + .header("Authorization", "test"); + + RSocket clientRSocket = + RSocketConnector.create() + .keepAlive(Duration.ofMinutes(10), Duration.ofMinutes(10)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(transport) + .block(); + + Flux.range(1, 100) + .concatMap(i -> clientRSocket.requestResponse(ByteBufPayload.create("Hello " + i))) + .doOnNext(payload -> logger.debug("Processed " + payload.getDataUtf8())) + .blockLast(); + clientRSocket.dispose(); + + logger.debug( + "\n\nStart of Unauthorized WebSocket Upgrade\n----------------------------------\n"); + + RSocketConnector.create() + .keepAlive(Duration.ofMinutes(10), Duration.ofMinutes(10)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(WebsocketClientTransport.create(server.host(), server.port())) + .block(); + } +} diff --git a/rsocket-examples/src/main/resources/logback.xml b/rsocket-examples/src/main/resources/logback.xml new file mode 100644 index 000000000..780a70c99 --- /dev/null +++ b/rsocket-examples/src/main/resources/logback.xml @@ -0,0 +1,34 @@ + + + + + + + + %d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] %c{1} - %m%n + + + + + + + + + + + + diff --git a/rsocket-examples/src/main/resources/lorem.txt b/rsocket-examples/src/main/resources/lorem.txt new file mode 100644 index 000000000..e035ea86d --- /dev/null +++ b/rsocket-examples/src/main/resources/lorem.txt @@ -0,0 +1,32 @@ +Alteration literature to or an sympathize mr imprudence. Of is ferrars subject as enjoyed or tedious cottage. +Procuring as in resembled by in agreeable. Next long no gave mr eyes. Admiration advantages no he celebrated so pianoforte unreserved. +Not its herself forming charmed amiable. Him why feebly expect future now. + +Situation admitting promotion at or to perceived be. Mr acuteness we as estimable enjoyment up. +An held late as felt know. Learn do allow solid to grave. Middleton suspicion age her attention. +Chiefly several bed its wishing. Is so moments on chamber pressed to. Doubtful yet way properly answered humanity its desirous. + Minuter believe service arrived civilly add all. Acuteness allowance an at eagerness favourite in extensive exquisite ye. + + Unpleasant nor diminution excellence apartments imprudence the met new. Draw part them he an to he roof only. + Music leave say doors him. Tore bred form if sigh case as do. Staying he no looking if do opinion. + Sentiments way understood end partiality and his. + + Ladyship it daughter securing procured or am moreover mr. Put sir she exercise vicinity cheerful wondered. + Continual say suspicion provision you neglected sir curiosity unwilling. Simplicity end themselves increasing led day sympathize yet. + General windows effects not are drawing man garrets. Common indeed garden you his ladies out yet. Preference imprudence contrasted to remarkably in on. + Taken now you him trees tears any. Her object giving end sister except oppose. + + No comfort do written conduct at prevent manners on. Celebrated contrasted discretion him sympathize her collecting occasional. + Do answered bachelor occasion in of offended no concerns. Supply worthy warmth branch of no ye. Voice tried known to as my to. + Though wished merits or be. Alone visit use these smart rooms ham. No waiting in on enjoyed placing it inquiry. + + So insisted received is occasion advanced honoured. Among ready to which up. Attacks smiling and may out assured moments man nothing outward. + Thrown any behind afford either the set depend one temper. Instrument melancholy in acceptance collecting frequently be if. + Zealously now pronounce existence add you instantly say offending. Merry their far had widen was. Concerns no in expenses raillery formerly. + + As am hastily invited settled at limited civilly fortune me. Really spring in extent an by. Judge but built gay party world. + Of so am he remember although required. Bachelor unpacked be advanced at. Confined in declared marianne is vicinity. + + In alteration insipidity impression by travelling reasonable up motionless. Of regard warmth by unable sudden garden ladies. + No kept hung am size spot no. Likewise led and dissuade rejoiced welcomed husbands boy. Do listening on he suspected resembled. + Water would still if to. Position boy required law moderate was may. \ No newline at end of file diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java index 34ed123ad..ac311a231 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,88 +16,106 @@ package io.rsocket.integration; -import static org.hamcrest.Matchers.is; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.assertTrue; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; -import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.plugins.DuplexConnectionInterceptor; import io.rsocket.plugins.RSocketInterceptor; +import io.rsocket.plugins.SocketAcceptorInterceptor; import io.rsocket.test.TestSubscriber; import io.rsocket.transport.netty.client.TcpClientTransport; -import io.rsocket.transport.netty.server.NettyContextCloseable; +import io.rsocket.transport.netty.server.CloseableChannel; import io.rsocket.transport.netty.server.TcpServerTransport; -import io.rsocket.util.PayloadImpl; +import io.rsocket.util.DefaultPayload; import io.rsocket.util.RSocketProxy; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicInteger; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; public class IntegrationTest { - private NettyContextCloseable server; - private RSocket client; - private AtomicInteger requestCount; - private CountDownLatch disconnectionCounter; - public static volatile boolean calledClient = false; - public static volatile boolean calledServer = false; - public static volatile boolean calledFrame = false; + private static final RSocketInterceptor requesterInterceptor; + private static final RSocketInterceptor responderInterceptor; + private static final SocketAcceptorInterceptor clientAcceptorInterceptor; + private static final SocketAcceptorInterceptor serverAcceptorInterceptor; + private static final DuplexConnectionInterceptor connectionInterceptor; - private static final RSocketInterceptor clientPlugin; - private static final RSocketInterceptor serverPlugin; - private static final DuplexConnectionInterceptor connectionPlugin; + private static volatile boolean calledRequester = false; + private static volatile boolean calledResponder = false; + private static volatile boolean calledClientAcceptor = false; + private static volatile boolean calledServerAcceptor = false; + private static volatile boolean calledFrame = false; static { - clientPlugin = + requesterInterceptor = reactiveSocket -> new RSocketProxy(reactiveSocket) { @Override public Mono requestResponse(Payload payload) { - calledClient = true; + calledRequester = true; return reactiveSocket.requestResponse(payload); } }; - serverPlugin = + responderInterceptor = reactiveSocket -> new RSocketProxy(reactiveSocket) { @Override public Mono requestResponse(Payload payload) { - calledServer = true; + calledResponder = true; return reactiveSocket.requestResponse(payload); } }; - connectionPlugin = + clientAcceptorInterceptor = + acceptor -> + (setup, sendingSocket) -> { + calledClientAcceptor = true; + return acceptor.accept(setup, sendingSocket); + }; + + serverAcceptorInterceptor = + acceptor -> + (setup, sendingSocket) -> { + calledServerAcceptor = true; + return acceptor.accept(setup, sendingSocket); + }; + + connectionInterceptor = (type, connection) -> { calledFrame = true; return connection; }; } - @Before + private CloseableChannel server; + private RSocket client; + private AtomicInteger requestCount; + private CountDownLatch disconnectionCounter; + private AtomicInteger errorCount; + + @BeforeEach public void startup() { + errorCount = new AtomicInteger(); requestCount = new AtomicInteger(); disconnectionCounter = new CountDownLatch(1); - TcpServerTransport serverTransport = TcpServerTransport.create(0); - server = - RSocketFactory.receive() - .addServerPlugin(serverPlugin) - .addConnectionPlugin(connectionPlugin) - .acceptor( + RSocketServer.create( (setup, sendingSocket) -> { sendingSocket .onClose() @@ -105,58 +123,86 @@ public void startup() { .subscribe(); return Mono.just( - new AbstractRSocket() { + new RSocket() { @Override public Mono requestResponse(Payload payload) { - return Mono.just(new PayloadImpl("RESPONSE", "METADATA")) + return Mono.just(DefaultPayload.create("RESPONSE", "METADATA")) .doOnSubscribe(s -> requestCount.incrementAndGet()); } @Override public Flux requestStream(Payload payload) { - return Flux.range(1, 10_000).map(i -> new PayloadImpl("data -> " + i)); + return Flux.range(1, 10_000) + .map(i -> DefaultPayload.create("data -> " + i)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads); } }); }) - .transport(serverTransport) - .start() + .interceptors( + registry -> + registry + .forResponder(responderInterceptor) + .forSocketAcceptor(serverAcceptorInterceptor) + .forConnection(connectionInterceptor)) + .bind(TcpServerTransport.create("localhost", 0)) .block(); client = - RSocketFactory.connect() - .addClientPlugin(clientPlugin) - .addConnectionPlugin(connectionPlugin) - .transport(TcpClientTransport.create(server.address())) - .start() + RSocketConnector.create() + .interceptors( + registry -> + registry + .forRequester(requesterInterceptor) + .forSocketAcceptor(clientAcceptorInterceptor) + .forConnection(connectionInterceptor)) + .connect(TcpClientTransport.create(server.address())) .block(); } - @After + @AfterEach public void teardown() { - server.close().block(); + server.dispose(); } - @Test(timeout = 5_000L) + @Test + @Timeout(5_000L) public void testRequest() { - client.requestResponse(new PayloadImpl("REQUEST", "META")).block(); - assertThat("Server did not see the request.", requestCount.get(), is(1)); - assertTrue(calledClient); - assertTrue(calledServer); - assertTrue(calledFrame); + client.requestResponse(DefaultPayload.create("REQUEST", "META")).block(); + assertThat(requestCount).as("Server did not see the request.").hasValue(1); + + assertThat(calledRequester).isTrue(); + assertThat(calledResponder).isTrue(); + assertThat(calledClientAcceptor).isTrue(); + assertThat(calledServerAcceptor).isTrue(); + assertThat(calledFrame).isTrue(); } @Test + @Timeout(5_000L) public void testStream() { Subscriber subscriber = TestSubscriber.createCancelling(); - client.requestStream(new PayloadImpl("start")).subscribe(subscriber); + client.requestStream(DefaultPayload.create("start")).subscribe(subscriber); verify(subscriber).onSubscribe(any()); verifyNoMoreInteractions(subscriber); } - @Test(timeout = 5_000L) + @Test + @Timeout(5_000L) public void testClose() throws InterruptedException { - client.close().block(); + client.dispose(); disconnectionCounter.await(); } + + @Test // (timeout = 5_000L) + public void testCallRequestWithErrorAndThenRequest() { + assertThatThrownBy(client.requestChannel(Mono.error(new Throwable("test")))::blockLast) + .hasMessage("java.lang.Throwable: test"); + + testRequest(); + } } diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java new file mode 100644 index 000000000..48e5baaa7 --- /dev/null +++ b/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java @@ -0,0 +1,104 @@ +package io.rsocket.integration; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.test.SlowTest; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.function.Supplier; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; + +public class InteractionsLoadTest { + + @Test + @SlowTest + public void channel() { + CloseableChannel server = + RSocketServer.create(SocketAcceptor.with(new EchoRSocket())) + .bind(TcpServerTransport.create("localhost", 0)) + .block(Duration.ofSeconds(10)); + + RSocket clientRSocket = + RSocketConnector.connectWith(TcpClientTransport.create(server.address())) + .block(Duration.ofSeconds(10)); + + int concurrency = 16; + Flux.range(1, concurrency) + .flatMap( + v -> + clientRSocket + .requestChannel( + input().onBackpressureDrop().map(iv -> DefaultPayload.create("foo"))) + .limitRate(10000), + concurrency) + .timeout(Duration.ofSeconds(5)) + .doOnNext( + p -> { + String data = p.getDataUtf8(); + if (!data.equals("bar")) { + throw new IllegalStateException("Channel Client Bad message: " + data); + } + }) + .window(Duration.ofSeconds(1)) + .flatMap(Flux::count) + .doOnNext(d -> System.out.println("Got: " + d)) + .take(Duration.ofMinutes(1)) + .doOnTerminate(server::dispose) + .subscribe(); + + server.onClose().block(); + } + + private static Flux input() { + Flux interval = Flux.interval(Duration.ofMillis(1)).onBackpressureDrop(); + for (int i = 0; i < 10; i++) { + interval = interval.mergeWith(interval); + } + return interval; + } + + private static class EchoRSocket implements RSocket { + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads) + .map( + p -> { + String data = p.getDataUtf8(); + if (!data.equals("foo")) { + throw new IllegalStateException("Channel Server Bad message: " + data); + } + return DefaultPayload.create("bar"); + }); + } + + @Override + public Flux requestStream(Payload payload) { + return Flux.just(payload) + .map( + p -> { + String data = p.getDataUtf8(); + return data; + }) + .doOnNext( + (data) -> { + if (!data.equals("foo")) { + throw new IllegalStateException("Stream Server Bad message: " + data); + } + }) + .flatMap( + data -> { + Supplier p = () -> DefaultPayload.create("bar"); + return Flux.range(1, 100).map(v -> p.get()); + }); + } + } +} diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java index 3a0067c12..1924668fb 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java @@ -1,75 +1,71 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2021 the original author or authors. * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package io.rsocket.integration; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; +import static org.assertj.core.api.Assertions.assertThat; -import io.rsocket.AbstractRSocket; import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.netty.client.TcpClientTransport; -import io.rsocket.transport.netty.server.NettyContextCloseable; +import io.rsocket.transport.netty.server.CloseableChannel; import io.rsocket.transport.netty.server.TcpServerTransport; -import io.rsocket.util.PayloadImpl; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; import io.rsocket.util.RSocketProxy; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.UnicastProcessor; +import reactor.core.publisher.Sinks; import reactor.core.scheduler.Schedulers; public class TcpIntegrationTest { - private AbstractRSocket handler; + private RSocket handler; - private NettyContextCloseable server; + private CloseableChannel server; - @Before + @BeforeEach public void startup() { - TcpServerTransport serverTransport = TcpServerTransport.create(0); server = - RSocketFactory.receive() - .acceptor((setup, sendingSocket) -> Mono.just(new RSocketProxy(handler))) - .transport(serverTransport) - .start() + RSocketServer.create((setup, sendingSocket) -> Mono.just(new RSocketProxy(handler))) + .bind(TcpServerTransport.create("localhost", 0)) .block(); } private RSocket buildClient() { - return RSocketFactory.connect() - .transport(TcpClientTransport.create(server.address())) - .start() - .block(); + return RSocketConnector.connectWith(TcpClientTransport.create(server.address())).block(); } - @After + @AfterEach public void cleanup() { - server.close().block(); + server.dispose(); } - @Test(timeout = 5_000L) + @Test + @Timeout(15_000L) public void testCompleteWithoutNext() { handler = - new AbstractRSocket() { + new RSocket() { @Override public Flux requestStream(Payload payload) { return Flux.empty(); @@ -77,49 +73,52 @@ public Flux requestStream(Payload payload) { }; RSocket client = buildClient(); Boolean hasElements = - client.requestStream(new PayloadImpl("REQUEST", "META")).log().hasElements().block(); + client.requestStream(DefaultPayload.create("REQUEST", "META")).log().hasElements().block(); - assertFalse(hasElements); + assertThat(hasElements).isFalse(); } - @Test(timeout = 5_000L) + @Test + @Timeout(15_000L) public void testSingleStream() { handler = - new AbstractRSocket() { + new RSocket() { @Override public Flux requestStream(Payload payload) { - return Flux.just(new PayloadImpl("RESPONSE", "METADATA")); + return Flux.just(DefaultPayload.create("RESPONSE", "METADATA")); } }; RSocket client = buildClient(); - Payload result = client.requestStream(new PayloadImpl("REQUEST", "META")).blockLast(); + Payload result = client.requestStream(DefaultPayload.create("REQUEST", "META")).blockLast(); - assertEquals("RESPONSE", result.getDataUtf8()); + assertThat(result.getDataUtf8()).isEqualTo("RESPONSE"); } - @Test(timeout = 5_000L) + @Test + @Timeout(15_000L) public void testZeroPayload() { handler = - new AbstractRSocket() { + new RSocket() { @Override public Flux requestStream(Payload payload) { - return Flux.just(PayloadImpl.EMPTY); + return Flux.just(EmptyPayload.INSTANCE); } }; RSocket client = buildClient(); - Payload result = client.requestStream(new PayloadImpl("REQUEST", "META")).blockFirst(); + Payload result = client.requestStream(DefaultPayload.create("REQUEST", "META")).blockFirst(); - assertEquals("", result.getDataUtf8()); + assertThat(result.getDataUtf8()).isEmpty(); } - @Test(timeout = 5_000L) + @Test + @Timeout(15_000L) public void testRequestResponseErrors() { handler = - new AbstractRSocket() { + new RSocket() { boolean first = true; @Override @@ -128,7 +127,7 @@ public Mono requestResponse(Payload payload) { first = false; return Mono.error(new RuntimeException("EX")); } else { - return Mono.just(new PayloadImpl("SUCCESS")); + return Mono.just(DefaultPayload.create("SUCCESS")); } } }; @@ -137,39 +136,40 @@ public Mono requestResponse(Payload payload) { Payload response1 = client - .requestResponse(new PayloadImpl("REQUEST", "META")) - .onErrorReturn(new PayloadImpl("ERROR")) + .requestResponse(DefaultPayload.create("REQUEST", "META")) + .onErrorReturn(DefaultPayload.create("ERROR")) .block(); Payload response2 = client - .requestResponse(new PayloadImpl("REQUEST", "META")) - .onErrorReturn(new PayloadImpl("ERROR")) + .requestResponse(DefaultPayload.create("REQUEST", "META")) + .onErrorReturn(DefaultPayload.create("ERROR")) .block(); - assertEquals("ERROR", response1.getDataUtf8()); - assertEquals("SUCCESS", response2.getDataUtf8()); + assertThat(response1.getDataUtf8()).isEqualTo("ERROR"); + assertThat(response2.getDataUtf8()).isEqualTo("SUCCESS"); } - @Test(timeout = 5_000L) + @Test + @Timeout(15_000L) public void testTwoConcurrentStreams() throws InterruptedException { - ConcurrentHashMap> map = new ConcurrentHashMap<>(); - UnicastProcessor processor1 = UnicastProcessor.create(); + ConcurrentHashMap> map = new ConcurrentHashMap<>(); + Sinks.Many processor1 = Sinks.many().unicast().onBackpressureBuffer(); map.put("REQUEST1", processor1); - UnicastProcessor processor2 = UnicastProcessor.create(); + Sinks.Many processor2 = Sinks.many().unicast().onBackpressureBuffer(); map.put("REQUEST2", processor2); handler = - new AbstractRSocket() { + new RSocket() { @Override public Flux requestStream(Payload payload) { - return map.get(payload.getDataUtf8()); + return map.get(payload.getDataUtf8()).asFlux(); } }; RSocket client = buildClient(); - Flux response1 = client.requestStream(new PayloadImpl("REQUEST1")); - Flux response2 = client.requestStream(new PayloadImpl("REQUEST2")); + Flux response1 = client.requestStream(DefaultPayload.create("REQUEST1")); + Flux response2 = client.requestStream(DefaultPayload.create("REQUEST2")); CountDownLatch nextCountdown = new CountDownLatch(2); CountDownLatch completeCountdown = new CountDownLatch(2); @@ -182,13 +182,13 @@ public Flux requestStream(Payload payload) { .subscribeOn(Schedulers.newSingle("2")) .subscribe(c -> nextCountdown.countDown(), t -> {}, completeCountdown::countDown); - processor1.onNext(new PayloadImpl("RESPONSE1A")); - processor2.onNext(new PayloadImpl("RESPONSE2A")); + processor1.tryEmitNext(DefaultPayload.create("RESPONSE1A")); + processor2.tryEmitNext(DefaultPayload.create("RESPONSE2A")); nextCountdown.await(); - processor1.onComplete(); - processor2.onComplete(); + processor1.tryEmitComplete(); + processor2.tryEmitComplete(); completeCountdown.await(); } diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java b/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java index c41af77d2..cd96584ed 100644 --- a/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java +++ b/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java @@ -1,66 +1,65 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.rsocket.integration; -import io.rsocket.*; -import io.rsocket.exceptions.ApplicationException; -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.exceptions.ApplicationErrorException; import io.rsocket.transport.local.LocalClientTransport; import io.rsocket.transport.local.LocalServerTransport; -import io.rsocket.util.PayloadImpl; +import io.rsocket.util.DefaultPayload; import java.time.Duration; import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Supplier; -import org.junit.Test; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; public class TestingStreaming { - private Supplier> serverSupplier = - () -> LocalServerTransport.create("test"); - - private Supplier clientSupplier = () -> LocalClientTransport.create("test"); + LocalServerTransport serverTransport = LocalServerTransport.create("test"); - @Test(expected = ApplicationException.class) + @Test public void testRangeButThrowException() { Closeable server = null; try { server = - RSocketFactory.receive() - .errorConsumer(Throwable::printStackTrace) - .acceptor( - (connectionSetupPayload, rSocket) -> { - AbstractRSocket abstractRSocket = - new AbstractRSocket() { - @Override - public double availability() { - return 1.0; - } - - @Override - public Flux requestStream(Payload payload) { - return Flux.range(1, 1000) - .doOnNext( - i -> { - if (i > 3) { - throw new RuntimeException("BOOM!"); - } - }) - .map(l -> new PayloadImpl("l -> " + l)) - .cast(Payload.class); - } - }; - - return Mono.just(abstractRSocket); - }) - .transport(serverSupplier.get()) - .start() + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> + Flux.range(1, 1000) + .doOnNext( + i -> { + if (i > 3) { + throw new RuntimeException("BOOM!"); + } + }) + .map(l -> DefaultPayload.create("l -> " + l)) + .cast(Payload.class))) + .bind(serverTransport) .block(); - Flux.range(1, 6).flatMap(i -> consumer("connection number -> " + i)).blockLast(); - System.out.println("here"); + Assertions.assertThatThrownBy( + Flux.range(1, 6).flatMap(i -> consumer("connection number -> " + i))::blockLast) + .isInstanceOf(ApplicationErrorException.class); } finally { - server.close().block(); + server.dispose(); } } @@ -69,86 +68,50 @@ public void testRangeOfConsumers() { Closeable server = null; try { server = - RSocketFactory.receive() - .errorConsumer(Throwable::printStackTrace) - .acceptor( - (connectionSetupPayload, rSocket) -> { - AbstractRSocket abstractRSocket = - new AbstractRSocket() { - @Override - public double availability() { - return 1.0; - } - - @Override - public Flux requestStream(Payload payload) { - return Flux.range(1, 1000) - .map(l -> new PayloadImpl("l -> " + l)) - .cast(Payload.class); - } - }; - - return Mono.just(abstractRSocket); - }) - .transport(serverSupplier.get()) - .start() + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> + Flux.range(1, 1000) + .map(l -> DefaultPayload.create("l -> " + l)) + .cast(Payload.class))) + .bind(serverTransport) .block(); Flux.range(1, 6).flatMap(i -> consumer("connection number -> " + i)).blockLast(); - System.out.println("here"); - } finally { - server.close().block(); + server.dispose(); } } private Flux consumer(String s) { - return RSocketFactory.connect() - .errorConsumer(Throwable::printStackTrace) - .transport(clientSupplier) - .start() + return RSocketConnector.connectWith(LocalClientTransport.create("test")) .flatMapMany( rSocket -> { AtomicInteger count = new AtomicInteger(); return Flux.range(1, 100) - .flatMap(i -> rSocket.requestStream(new PayloadImpl("i -> " + i)).take(100), 1); + .flatMap( + i -> rSocket.requestStream(DefaultPayload.create("i -> " + i)).take(100), 1); }); } @Test public void testSingleConsumer() { Closeable server = null; - try { server = - RSocketFactory.receive() - .acceptor( - (connectionSetupPayload, rSocket) -> { - AbstractRSocket abstractRSocket = - new AbstractRSocket() { - @Override - public double availability() { - return 1.0; - } - - @Override - public Flux requestStream(Payload payload) { - return Flux.range(1, 10_000) - .map(l -> new PayloadImpl("l -> " + l)) - .cast(Payload.class); - } - }; - - return Mono.just(abstractRSocket); - }) - .transport(serverSupplier.get()) - .start() + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> + Flux.range(1, 10_000) + .map(l -> DefaultPayload.create("l -> " + l)) + .cast(Payload.class))) + .bind(serverTransport) .block(); consumer("1").blockLast(); } finally { - server.close().block(); + server.dispose(); } } diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/observation/ObservationIntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/observation/ObservationIntegrationTest.java new file mode 100644 index 000000000..870ecf0cd --- /dev/null +++ b/rsocket-examples/src/test/java/io/rsocket/integration/observation/ObservationIntegrationTest.java @@ -0,0 +1,246 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.integration.observation; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.Tag; +import io.micrometer.core.instrument.Tags; +import io.micrometer.core.instrument.observation.DefaultMeterObservationHandler; +import io.micrometer.core.instrument.simple.SimpleMeterRegistry; +import io.micrometer.core.tck.MeterRegistryAssert; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationHandler; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.tracing.test.SampleTestRunner; +import io.micrometer.tracing.test.reporter.BuildingBlocks; +import io.micrometer.tracing.test.simple.SpansAssert; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.micrometer.observation.ByteBufGetter; +import io.rsocket.micrometer.observation.ByteBufSetter; +import io.rsocket.micrometer.observation.ObservationRequesterRSocketProxy; +import io.rsocket.micrometer.observation.ObservationResponderRSocketProxy; +import io.rsocket.micrometer.observation.RSocketRequesterTracingObservationHandler; +import io.rsocket.micrometer.observation.RSocketResponderTracingObservationHandler; +import io.rsocket.plugins.RSocketInterceptor; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.Deque; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiConsumer; +import org.awaitility.Awaitility; +import org.junit.jupiter.api.AfterEach; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class ObservationIntegrationTest extends SampleTestRunner { + private static final MeterRegistry registry = new SimpleMeterRegistry(); + private static final ObservationRegistry observationRegistry = ObservationRegistry.create(); + + static { + observationRegistry + .observationConfig() + .observationHandler(new DefaultMeterObservationHandler(registry)); + } + + private final RSocketInterceptor requesterInterceptor; + private final RSocketInterceptor responderInterceptor; + + ObservationIntegrationTest() { + super(SampleRunnerConfig.builder().build()); + requesterInterceptor = + reactiveSocket -> new ObservationRequesterRSocketProxy(reactiveSocket, observationRegistry); + + responderInterceptor = + reactiveSocket -> new ObservationResponderRSocketProxy(reactiveSocket, observationRegistry); + } + + private CloseableChannel server; + private RSocket client; + private AtomicInteger counter; + + @Override + public BiConsumer>> + customizeObservationHandlers() { + return (buildingBlocks, observationHandlers) -> { + observationHandlers.addFirst( + new RSocketRequesterTracingObservationHandler( + buildingBlocks.getTracer(), + buildingBlocks.getPropagator(), + new ByteBufSetter(), + false)); + observationHandlers.addFirst( + new RSocketResponderTracingObservationHandler( + buildingBlocks.getTracer(), + buildingBlocks.getPropagator(), + new ByteBufGetter(), + false)); + }; + } + + @AfterEach + public void teardown() { + if (server != null) { + server.dispose(); + } + } + + private void testRequest() { + counter.set(0); + client.requestResponse(DefaultPayload.create("REQUEST", "META")).block(); + assertThat(counter).as("Server did not see the request.").hasValue(1); + } + + private void testStream() { + counter.set(0); + client.requestStream(DefaultPayload.create("start")).blockLast(); + + assertThat(counter).as("Server did not see the request.").hasValue(1); + } + + private void testRequestChannel() { + counter.set(0); + client.requestChannel(Mono.just(DefaultPayload.create("start"))).blockFirst(); + assertThat(counter).as("Server did not see the request.").hasValue(1); + } + + private void testFireAndForget() { + counter.set(0); + client.fireAndForget(DefaultPayload.create("start")).subscribe(); + Awaitility.await().atMost(Duration.ofSeconds(50)).until(() -> counter.get() == 1); + assertThat(counter).as("Server did not see the request.").hasValue(1); + } + + @Override + public SampleTestRunnerConsumer yourCode() { + return (bb, meterRegistry) -> { + counter = new AtomicInteger(); + server = + RSocketServer.create( + (setup, sendingSocket) -> { + sendingSocket.onClose().subscribe(); + + return Mono.just( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + payload.release(); + counter.incrementAndGet(); + return Mono.just(DefaultPayload.create("RESPONSE", "METADATA")); + } + + @Override + public Flux requestStream(Payload payload) { + payload.release(); + counter.incrementAndGet(); + return Flux.range(1, 10_000) + .map(i -> DefaultPayload.create("data -> " + i)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + counter.incrementAndGet(); + return Flux.from(payloads); + } + + @Override + public Mono fireAndForget(Payload payload) { + payload.release(); + counter.incrementAndGet(); + return Mono.empty(); + } + }); + }) + .interceptors(registry -> registry.forResponder(responderInterceptor)) + .bind(TcpServerTransport.create("localhost", 0)) + .block(); + + client = + RSocketConnector.create() + .interceptors(registry -> registry.forRequester(requesterInterceptor)) + .connect(TcpClientTransport.create(server.address())) + .block(); + + testRequest(); + + testStream(); + + testRequestChannel(); + + testFireAndForget(); + + // @formatter:off + SpansAssert.assertThat(bb.getFinishedSpans()) + .haveSameTraceId() + // "request_*" + "handle" x 4 + .hasNumberOfSpansEqualTo(8) + .hasNumberOfSpansWithNameEqualTo("handle", 4) + .forAllSpansWithNameEqualTo("handle", span -> span.hasTagWithKey("rsocket.request-type")) + .hasASpanWithNameIgnoreCase("request_stream") + .thenASpanWithNameEqualToIgnoreCase("request_stream") + .hasTag("rsocket.request-type", "REQUEST_STREAM") + .backToSpans() + .hasASpanWithNameIgnoreCase("request_channel") + .thenASpanWithNameEqualToIgnoreCase("request_channel") + .hasTag("rsocket.request-type", "REQUEST_CHANNEL") + .backToSpans() + .hasASpanWithNameIgnoreCase("request_fnf") + .thenASpanWithNameEqualToIgnoreCase("request_fnf") + .hasTag("rsocket.request-type", "REQUEST_FNF") + .backToSpans() + .hasASpanWithNameIgnoreCase("request_response") + .thenASpanWithNameEqualToIgnoreCase("request_response") + .hasTag("rsocket.request-type", "REQUEST_RESPONSE"); + + MeterRegistryAssert.assertThat(registry) + .hasTimerWithNameAndTags( + "rsocket.response", + Tags.of(Tag.of("error", "none"), Tag.of("rsocket.request-type", "REQUEST_RESPONSE"))) + .hasTimerWithNameAndTags( + "rsocket.fnf", + Tags.of(Tag.of("error", "none"), Tag.of("rsocket.request-type", "REQUEST_FNF"))) + .hasTimerWithNameAndTags( + "rsocket.request", + Tags.of(Tag.of("error", "none"), Tag.of("rsocket.request-type", "REQUEST_RESPONSE"))) + .hasTimerWithNameAndTags( + "rsocket.channel", + Tags.of(Tag.of("error", "none"), Tag.of("rsocket.request-type", "REQUEST_CHANNEL"))) + .hasTimerWithNameAndTags( + "rsocket.stream", + Tags.of(Tag.of("error", "none"), Tag.of("rsocket.request-type", "REQUEST_STREAM"))); + // @formatter:on + }; + } + + @Override + protected MeterRegistry getMeterRegistry() { + return registry; + } + + @Override + protected ObservationRegistry getObservationRegistry() { + return observationRegistry; + } +} diff --git a/rsocket-examples/src/test/java/io/rsocket/resume/DisconnectableClientTransport.java b/rsocket-examples/src/test/java/io/rsocket/resume/DisconnectableClientTransport.java new file mode 100644 index 000000000..5824918bc --- /dev/null +++ b/rsocket-examples/src/test/java/io/rsocket/resume/DisconnectableClientTransport.java @@ -0,0 +1,75 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.rsocket.DuplexConnection; +import io.rsocket.transport.ClientTransport; +import java.nio.channels.ClosedChannelException; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicReference; +import reactor.core.publisher.Mono; + +class DisconnectableClientTransport implements ClientTransport { + private final ClientTransport clientTransport; + private final AtomicReference curConnection = new AtomicReference<>(); + private long nextConnectPermitMillis; + + public DisconnectableClientTransport(ClientTransport clientTransport) { + this.clientTransport = clientTransport; + } + + @Override + public Mono connect() { + return Mono.defer( + () -> + now() < nextConnectPermitMillis + ? Mono.error(new ClosedChannelException()) + : clientTransport + .connect() + .map( + c -> { + if (curConnection.compareAndSet(null, c)) { + return c; + } else { + throw new IllegalStateException( + "Transport supports at most 1 connection"); + } + })); + } + + public void disconnect() { + disconnectFor(Duration.ZERO); + } + + public void disconnectPermanently() { + disconnectFor(Duration.ofDays(42)); + } + + public void disconnectFor(Duration cooldown) { + DuplexConnection cur = curConnection.getAndSet(null); + if (cur != null) { + nextConnectPermitMillis = now() + cooldown.toMillis(); + cur.dispose(); + } else { + throw new IllegalStateException("Trying to disconnect while not connected"); + } + } + + private static long now() { + return System.currentTimeMillis(); + } +} diff --git a/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java new file mode 100644 index 000000000..5eb78fabe --- /dev/null +++ b/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java @@ -0,0 +1,229 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.core.Resume; +import io.rsocket.exceptions.RejectedResumeException; +import io.rsocket.exceptions.UnsupportedSetupException; +import io.rsocket.test.SlowTest; +import io.rsocket.transport.ClientTransport; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.net.InetSocketAddress; +import java.nio.channels.ClosedChannelException; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicInteger; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.util.retry.Retry; + +@SlowTest +public class ResumeIntegrationTest { + private static final String SERVER_HOST = "localhost"; + private static final int SERVER_PORT = 0; + + @Test + void timeoutOnPermanentDisconnect() { + CloseableChannel closeable = newServerRSocket().block(); + + DisconnectableClientTransport clientTransport = + new DisconnectableClientTransport(clientTransport(closeable.address())); + + int sessionDurationSeconds = 5; + RSocket rSocket = newClientRSocket(clientTransport, sessionDurationSeconds).block(); + + Mono.delay(Duration.ofSeconds(1)).subscribe(v -> clientTransport.disconnectPermanently()); + + StepVerifier.create( + rSocket.requestChannel(testRequest()).then().doFinally(s -> closeable.dispose())) + .expectError(ClosedChannelException.class) + .verify(Duration.ofSeconds(7)); + } + + @Test + public void reconnectOnDisconnect() { + CloseableChannel closeable = newServerRSocket().block(); + + DisconnectableClientTransport clientTransport = + new DisconnectableClientTransport(clientTransport(closeable.address())); + + int sessionDurationSeconds = 15; + RSocket rSocket = newClientRSocket(clientTransport, sessionDurationSeconds).block(); + + Flux.just(3, 20, 40, 75) + .flatMap(v -> Mono.delay(Duration.ofSeconds(v))) + .subscribe(v -> clientTransport.disconnectFor(Duration.ofSeconds(7))); + + AtomicInteger counter = new AtomicInteger(-1); + StepVerifier.create( + rSocket + .requestChannel(testRequest()) + .take(Duration.ofSeconds(600)) + .map(Payload::getDataUtf8) + .timeout(Duration.ofSeconds(12)) + .doOnNext(x -> throwOnNonContinuous(counter, x)) + .then() + .doFinally(s -> closeable.dispose())) + .expectComplete() + .verify(); + } + + @Test + public void reconnectOnMissingSession() { + + int serverSessionDuration = 2; + + CloseableChannel closeable = newServerRSocket(serverSessionDuration).block(); + + DisconnectableClientTransport clientTransport = + new DisconnectableClientTransport(clientTransport(closeable.address())); + int clientSessionDurationSeconds = 10; + + RSocket rSocket = newClientRSocket(clientTransport, clientSessionDurationSeconds).block(); + + Mono.delay(Duration.ofSeconds(1)) + .subscribe(v -> clientTransport.disconnectFor(Duration.ofSeconds(3))); + + StepVerifier.create( + rSocket.requestChannel(testRequest()).then().doFinally(s -> closeable.dispose())) + .expectError() + .verify(Duration.ofSeconds(5)); + + StepVerifier.create(rSocket.onClose()) + .expectErrorMatches( + err -> + err instanceof RejectedResumeException + && "unknown resume token".equals(err.getMessage())) + .verify(Duration.ofSeconds(5)); + } + + @Test + void serverMissingResume() { + CloseableChannel closeableChannel = + RSocketServer.create(SocketAcceptor.with(new TestResponderRSocket())) + .bind(serverTransport(SERVER_HOST, SERVER_PORT)) + .block(); + + RSocket rSocket = + RSocketConnector.create() + .resume(new Resume()) + .connect(clientTransport(closeableChannel.address())) + .block(); + + StepVerifier.create(rSocket.onClose().doFinally(s -> closeableChannel.dispose())) + .expectErrorMatches( + err -> + err instanceof UnsupportedSetupException + && "resume not supported".equals(err.getMessage())) + .verify(Duration.ofSeconds(5)); + + Assertions.assertThat(rSocket.isDisposed()).isTrue(); + } + + static ClientTransport clientTransport(InetSocketAddress address) { + return TcpClientTransport.create(address); + } + + static ServerTransport serverTransport(String host, int port) { + return TcpServerTransport.create(host, port); + } + + private static Flux testRequest() { + return Flux.interval(Duration.ofMillis(500)) + .map(v -> DefaultPayload.create("client_request")) + .onBackpressureDrop(); + } + + private void throwOnNonContinuous(AtomicInteger counter, String x) { + int curValue = Integer.parseInt(x); + int prevValue = counter.get(); + if (prevValue >= 0) { + int dif = curValue - prevValue; + if (dif != 1) { + throw new IllegalStateException( + String.format( + "Payload values are expected to be continuous numbers: %d %d", + prevValue, curValue)); + } + } + counter.set(curValue); + } + + private static Mono newClientRSocket( + DisconnectableClientTransport clientTransport, int sessionDurationSeconds) { + return RSocketConnector.create() + .resume( + new Resume() + .sessionDuration(Duration.ofSeconds(sessionDurationSeconds)) + .storeFactory(t -> new InMemoryResumableFramesStore("client", t, 500_000)) + .cleanupStoreOnKeepAlive() + .retry(Retry.fixedDelay(Long.MAX_VALUE, Duration.ofSeconds(1)))) + .keepAlive(Duration.ofSeconds(5), Duration.ofMinutes(5)) + .connect(clientTransport); + } + + private static Mono newServerRSocket() { + return newServerRSocket(15); + } + + private static Mono newServerRSocket(int sessionDurationSeconds) { + return RSocketServer.create(SocketAcceptor.with(new TestResponderRSocket())) + .resume( + new Resume() + .sessionDuration(Duration.ofSeconds(sessionDurationSeconds)) + .cleanupStoreOnKeepAlive() + .storeFactory(t -> new InMemoryResumableFramesStore("server", t, 500_000))) + .bind(serverTransport(SERVER_HOST, SERVER_PORT)); + } + + private static class TestResponderRSocket implements RSocket { + + AtomicInteger counter = new AtomicInteger(); + + @Override + public Flux requestChannel(Publisher payloads) { + return duplicate( + Flux.interval(Duration.ofMillis(1)) + .onBackpressureLatest() + .publishOn(Schedulers.boundedElastic()), + 20) + .map(v -> DefaultPayload.create(String.valueOf(counter.getAndIncrement()))) + .takeUntilOther(Flux.from(payloads).then()); + } + + private Flux duplicate(Flux f, int n) { + Flux r = Flux.empty(); + for (int i = 0; i < n; i++) { + r = r.mergeWith(f); + } + return r; + } + } +} diff --git a/rsocket-examples/src/test/resources/log4j.properties b/rsocket-examples/src/test/resources/log4j.properties deleted file mode 100644 index 3700f1f6e..000000000 --- a/rsocket-examples/src/test/resources/log4j.properties +++ /dev/null @@ -1,18 +0,0 @@ -# -# Copyright 2016 Netflix, Inc. -#

-# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -#

-# http://www.apache.org/licenses/LICENSE-2.0 -#

-# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on -# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. -# -log4j.rootLogger=INFO, stdout - -log4j.appender.stdout=org.apache.log4j.ConsoleAppender -log4j.appender.stdout.layout=org.apache.log4j.PatternLayout -log4j.appender.stdout.layout.ConversionPattern=%d{HH:mm:ss,SSS} %5p [%t] (%F) - %m%n -#log4j.logger.io.rsocket.FrameLogger=Debug \ No newline at end of file diff --git a/rsocket-examples/src/test/resources/logback-test.xml b/rsocket-examples/src/test/resources/logback-test.xml new file mode 100644 index 000000000..13e65b37d --- /dev/null +++ b/rsocket-examples/src/test/resources/logback-test.xml @@ -0,0 +1,33 @@ + + + + + + + + %d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] %c{1} - %m%n + + + + + + + + + + + diff --git a/rsocket-load-balancer/build.gradle b/rsocket-load-balancer/build.gradle index 8ce69b411..6d91324ae 100644 --- a/rsocket-load-balancer/build.gradle +++ b/rsocket-load-balancer/build.gradle @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -14,7 +14,26 @@ * limitations under the License. */ +plugins { + id 'java-library' + id 'maven-publish' + id 'signing' +} + dependencies { - compile project(':rsocket-core') - testCompile project(':rsocket-test') + api project(':rsocket-core') + + implementation 'org.slf4j:slf4j-api' + + testImplementation project(':rsocket-test') + testImplementation 'org.junit.jupiter:junit-jupiter-api' + testImplementation 'org.junit.jupiter:junit-jupiter-params' + testImplementation 'org.mockito:mockito-core' + testImplementation 'org.assertj:assertj-core' + testImplementation 'io.projectreactor:reactor-test' + + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' + testRuntimeOnly 'ch.qos.logback:logback-classic' } + +description = 'Transparent Load Balancer for RSocket' diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java index 76dd15ad2..6329da826 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,33 +13,27 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.client; -import static io.rsocket.util.ExceptionUtil.noStacktrace; +package io.rsocket.client; -import io.rsocket.Availability; -import io.rsocket.Closeable; -import io.rsocket.Payload; -import io.rsocket.RSocket; +import io.rsocket.*; import io.rsocket.client.filter.RSocketSupplier; -import io.rsocket.exceptions.NoAvailableRSocketException; -import io.rsocket.exceptions.TimeoutException; -import io.rsocket.exceptions.TransportException; import io.rsocket.stat.Ewma; import io.rsocket.stat.FrugalQuantile; import io.rsocket.stat.Median; import io.rsocket.stat.Quantile; import io.rsocket.util.Clock; -import io.rsocket.util.RSocketProxy; import java.nio.channels.ClosedChannelException; -import java.util.*; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Optional; +import java.util.Random; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -48,18 +42,21 @@ import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; import reactor.core.publisher.Operators; +import reactor.util.context.Context; +import reactor.util.retry.Retry; /** * An implementation of {@link Mono} that load balances across a pool of RSockets and emits one when * it is subscribed to * *

It estimates the load of each RSocket based on statistics collected. + * + * @deprecated as of 1.1. in favor of {@link io.rsocket.loadbalance.LoadbalanceRSocketClient}. */ +@Deprecated public abstract class LoadBalancedRSocketMono extends Mono implements Availability, Closeable { - private static final Logger logger = LoggerFactory.getLogger(LoadBalancedRSocketMono.class); - public static final double DEFAULT_EXP_FACTOR = 4.0; public static final double DEFAULT_LOWER_QUANTILE = 0.2; public static final double DEFAULT_HIGHER_QUANTILE = 0.8; @@ -69,39 +66,36 @@ public abstract class LoadBalancedRSocketMono extends Mono public static final int DEFAULT_MAX_APERTURE = 100; public static final long DEFAULT_MAX_REFRESH_PERIOD_MS = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES); - + private static final Logger logger = LoggerFactory.getLogger(LoadBalancedRSocketMono.class); private static final long APERTURE_REFRESH_PERIOD = Clock.unit().convert(15, TimeUnit.SECONDS); private static final int EFFORT = 5; private static final long DEFAULT_INITIAL_INTER_ARRIVAL_TIME = Clock.unit().convert(1L, TimeUnit.SECONDS); private static final int DEFAULT_INTER_ARRIVAL_FACTOR = 500; + private static final FailingRSocket FAILING_REACTIVE_SOCKET = new FailingRSocket(); + protected final Mono rSocketMono; private final double minPendings; private final double maxPendings; private final int minAperture; private final int maxAperture; private final long maxRefreshPeriod; - private final double expFactor; private final Quantile lowerQuantile; private final Quantile higherQuantile; - - private int pendingSockets; private final ArrayList activeSockets; - private final ArrayList activeFactories; - private final FactoriesRefresher factoryRefresher; - private final Mono selectSocket; - private final Ewma pendings; + private final MonoProcessor onClose = MonoProcessor.create(); + private final RSocketSupplierPool pool; + private final long weightedSocketRetries; + private final Duration weightedSocketBackOff; + private final Duration weightedSocketMaxBackOff; private volatile int targetAperture; private long lastApertureRefresh; private long refreshPeriod; + private int pendingSockets; private volatile long lastRefresh; - private final MonoProcessor onClose = MonoProcessor.create(); - protected final MonoProcessor started = MonoProcessor.create(); - protected final Mono rSocketMono; - /** * @param factories the source (factories) of RSocket * @param expFactor how aggressive is the algorithm toward outliers. A higher number means we send @@ -118,6 +112,11 @@ public abstract class LoadBalancedRSocketMono extends Mono * load. * @param maxRefreshPeriodMs the maximum time between two "refreshes" of the list of active * RSocket. This is at that time that the slowest RSocket is closed. (unit is millisecond) + * @param weightedSocketRetries the number of times a weighted socket will attempt to retry when + * it receives an error before reconnecting. The default is 5 times. + * @param weightedSocketBackOff the duration a a weighted socket will add to each retry attempt. + * @param weightedSocketMaxBackOff the max duration a weighted socket will delay before retrying + * to connect. The default is 5 seconds. */ private LoadBalancedRSocketMono( Publisher> factories, @@ -128,16 +127,19 @@ private LoadBalancedRSocketMono( double maxPendings, int minAperture, int maxAperture, - long maxRefreshPeriodMs) { + long maxRefreshPeriodMs, + long weightedSocketRetries, + Duration weightedSocketBackOff, + Duration weightedSocketMaxBackOff) { + this.weightedSocketRetries = weightedSocketRetries; + this.weightedSocketBackOff = weightedSocketBackOff; + this.weightedSocketMaxBackOff = weightedSocketMaxBackOff; this.expFactor = expFactor; this.lowerQuantile = new FrugalQuantile(lowQuantile); this.higherQuantile = new FrugalQuantile(highQuantile); this.activeSockets = new ArrayList<>(); - this.activeFactories = new ArrayList<>(); this.pendingSockets = 0; - this.factoryRefresher = new FactoriesRefresher(); - this.selectSocket = Mono.fromCallable(this::select); this.minPendings = minPendings; this.maxPendings = maxPendings; @@ -151,15 +153,12 @@ private LoadBalancedRSocketMono( this.lastApertureRefresh = Clock.now(); this.refreshPeriod = Clock.unit().convert(15L, TimeUnit.SECONDS); this.lastRefresh = Clock.now(); + this.pool = new RSocketSupplierPool(factories); + refreshSockets(); - factories.subscribe(factoryRefresher); + rSocketMono = Mono.fromSupplier(this::select); - rSocketMono = - Mono.create( - sink -> { - RSocket rSocket = select(); - sink.success(rSocket); - }); + onClose.doFinally(signalType -> pool.dispose()).subscribe(); } public static LoadBalancedRSocketMono create( @@ -176,6 +175,39 @@ public static LoadBalancedRSocketMono create( DEFAULT_MAX_REFRESH_PERIOD_MS); } + public static LoadBalancedRSocketMono create( + Publisher> factories, + double expFactor, + double lowQuantile, + double highQuantile, + double minPendings, + double maxPendings, + int minAperture, + int maxAperture, + long maxRefreshPeriodMs, + long weightedSocketRetries, + Duration weightedSocketBackOff, + Duration weightedSocketMaxBackOff) { + return new LoadBalancedRSocketMono( + factories, + expFactor, + lowQuantile, + highQuantile, + minPendings, + maxPendings, + minAperture, + maxAperture, + maxRefreshPeriodMs, + weightedSocketRetries, + weightedSocketBackOff, + weightedSocketMaxBackOff) { + @Override + public void subscribe(CoreSubscriber s) { + rSocketMono.subscribe(s); + } + }; + } + public static LoadBalancedRSocketMono create( Publisher> factories, double expFactor, @@ -195,72 +227,65 @@ public static LoadBalancedRSocketMono create( maxPendings, minAperture, maxAperture, - maxRefreshPeriodMs) { + maxRefreshPeriodMs, + 5, + Duration.ofMillis(500), + Duration.ofSeconds(5)) { @Override public void subscribe(CoreSubscriber s) { - started.thenMany(rSocketMono).subscribe(s); + rSocketMono.subscribe(s); } }; } + /** + * Responsible for: - refreshing the aperture - asynchronously adding/removing reactive sockets to + * match targetAperture - periodically append a new connection + */ + private synchronized void refreshSockets() { + refreshAperture(); + int n = activeSockets.size(); + if (n < targetAperture && !pool.isPoolEmpty()) { + logger.debug( + "aperture {} is below target {}, adding {} sockets", + n, + targetAperture, + targetAperture - n); + addSockets(targetAperture - n); + } else if (targetAperture < activeSockets.size()) { + logger.debug("aperture {} is above target {}, quicking 1 socket", n, targetAperture); + quickSlowestRS(); + } + + long now = Clock.now(); + if (now - lastRefresh >= refreshPeriod) { + long prev = refreshPeriod; + refreshPeriod = (long) Math.min(refreshPeriod * 1.5, maxRefreshPeriod); + logger.debug("Bumping refresh period, {}->{}", prev / 1000, refreshPeriod / 1000); + lastRefresh = now; + addSockets(1); + } + } + private synchronized void addSockets(int numberOfNewSocket) { int n = numberOfNewSocket; - if (n > activeFactories.size()) { - n = activeFactories.size(); + int poolSize = pool.poolSize(); + if (n > poolSize) { + n = poolSize; logger.debug( "addSockets({}) restricted by the number of factories, i.e. addSockets({})", numberOfNewSocket, n); } - Random rng = ThreadLocalRandom.current(); - while (n > 0) { - int size = activeFactories.size(); - if (size == 1) { - RSocketSupplier factory = activeFactories.get(0); - if (factory.availability() > 0.0) { - activeFactories.remove(0); - pendingSockets++; - factory.get().subscribe(new SocketAdder(factory)); - } - break; - } - RSocketSupplier factory0 = null; - RSocketSupplier factory1 = null; - int i0 = 0; - int i1 = 0; - for (int i = 0; i < EFFORT; i++) { - i0 = rng.nextInt(size); - i1 = rng.nextInt(size - 1); - if (i1 >= i0) { - i1++; - } - factory0 = activeFactories.get(i0); - factory1 = activeFactories.get(i1); - if (factory0.availability() > 0.0 && factory1.availability() > 0.0) { - break; - } - } + for (int i = 0; i < n; i++) { + Optional optional = pool.get(); - if (factory0.availability() < factory1.availability()) { - n--; - pendingSockets++; - // cheaper to permute activeFactories.get(i1) with the last item and remove the last - // rather than doing a activeFactories.remove(i1) - if (i1 < size - 1) { - activeFactories.set(i1, activeFactories.get(size - 1)); - } - activeFactories.remove(size - 1); - factory1.get().subscribe(new SocketAdder(factory1)); + if (optional.isPresent()) { + RSocketSupplier supplier = optional.get(); + WeightedSocket socket = new WeightedSocket(supplier, lowerQuantile, higherQuantile); } else { - n--; - pendingSockets++; - // c.f. above - if (i0 < size - 1) { - activeFactories.set(i0, activeFactories.get(size - 1)); - } - activeFactories.remove(size - 1); - factory0.get().subscribe(new SocketAdder(factory0)); + break; } } } @@ -298,7 +323,7 @@ private void updateAperture(int newValue, long now) { int previous = targetAperture; targetAperture = newValue; targetAperture = Math.max(minAperture, targetAperture); - int maxAperture = Math.min(this.maxAperture, activeSockets.size() + activeFactories.size()); + int maxAperture = Math.min(this.maxAperture, activeSockets.size() + pool.poolSize()); targetAperture = Math.min(maxAperture, targetAperture); lastApertureRefresh = now; pendings.reset((minPendings + maxPendings) / 2); @@ -312,35 +337,6 @@ private void updateAperture(int newValue, long now) { } } - /** - * Responsible for: - refreshing the aperture - asynchronously adding/removing reactive sockets to - * match targetAperture - periodically append a new connection - */ - private synchronized void refreshSockets() { - refreshAperture(); - int n = pendingSockets + activeSockets.size(); - if (n < targetAperture && !activeFactories.isEmpty()) { - logger.debug( - "aperture {} is below target {}, adding {} sockets", - n, - targetAperture, - targetAperture - n); - addSockets(targetAperture - n); - } else if (targetAperture < activeSockets.size()) { - logger.debug("aperture {} is above target {}, quicking 1 socket", n, targetAperture); - quickSlowestRS(); - } - - long now = Clock.now(); - if (now - lastRefresh >= refreshPeriod) { - long prev = refreshPeriod; - refreshPeriod = (long) Math.min(refreshPeriod * 1.5, maxRefreshPeriod); - logger.debug("Bumping refresh period, {}->{}", prev / 1000, refreshPeriod / 1000); - lastRefresh = now; - addSockets(1); - } - } - private synchronized void quickSlowestRS() { if (activeSockets.size() <= 1) { return; @@ -364,21 +360,8 @@ private synchronized void quickSlowestRS() { } if (slowest != null) { - removeSocket(slowest, false); - } - } - - private synchronized void removeSocket(WeightedSocket socket, boolean refresh) { - try { - logger.debug("Removing socket: -> " + socket); - activeSockets.remove(socket); - activeFactories.add(socket.getFactory()); - socket.close().subscribe(); - if (refresh) { - refreshSockets(); - } - } catch (Exception e) { - logger.warn("Exception while closing a RSocket", e); + logger.debug("Disposing slowest WeightedSocket {}", slowest); + slowest.dispose(); } } @@ -396,10 +379,11 @@ public synchronized double availability() { } private synchronized RSocket select() { + refreshSockets(); + if (activeSockets.isEmpty()) { return FAILING_REACTIVE_SOCKET; } - refreshSockets(); int size = activeSockets.size(); if (size == 1) { @@ -421,7 +405,7 @@ private synchronized RSocket select() { if (rsc1.availability() > 0.0 && rsc2.availability() > 0.0) { break; } - if (i + 1 == EFFORT && !activeFactories.isEmpty()) { + if (i + 1 == EFFORT && !pool.isPoolEmpty()) { addSockets(1); } } @@ -468,7 +452,7 @@ public synchronized String toString() { return "LoadBalancer(a:" + activeSockets.size() + ", f: " - + activeFactories.size() + + pool.poolSize() + ", avgPendings=" + pendings.value() + ", targetAperture=" @@ -481,205 +465,33 @@ public synchronized String toString() { } @Override - public Mono onClose() { - return onClose; + public void dispose() { + synchronized (this) { + activeSockets.forEach(WeightedSocket::dispose); + activeSockets.clear(); + onClose.onComplete(); + } } @Override - public Mono close() { - return Mono.from( - subscriber -> { - subscriber.onSubscribe(Operators.emptySubscription()); - - synchronized (this) { - factoryRefresher.close(); - activeFactories.clear(); - AtomicInteger n = new AtomicInteger(activeSockets.size()); - - activeSockets.forEach( - rs -> - rs.close() - .subscribe( - new Subscriber() { - @Override - public void onSubscribe(Subscription s) { - s.request(Long.MAX_VALUE); - } - - @Override - public void onNext(Void aVoid) {} - - @Override - public void onError(Throwable t) { - logger.warn("Exception while closing a RSocket", t); - onComplete(); - } - - @Override - public void onComplete() { - if (n.decrementAndGet() == 0) { - subscriber.onComplete(); - onClose.onComplete(); - } - } - })); - } - }); + public boolean isDisposed() { + return onClose.isDisposed(); } - /** - * This subscriber role is to subscribe to the list of server identifier, and update the factory - * list. - */ - private class FactoriesRefresher implements Subscriber> { - private Subscription subscription; - - @Override - public void onSubscribe(Subscription subscription) { - this.subscription = subscription; - subscription.request(Long.MAX_VALUE); - } - - @Override - public void onNext(Collection newFactories) { - synchronized (LoadBalancedRSocketMono.this) { - Set current = new HashSet<>(activeFactories.size() + activeSockets.size()); - current.addAll(activeFactories); - for (WeightedSocket socket : activeSockets) { - RSocketSupplier factory = socket.getFactory(); - current.add(factory); - } - - Set removed = new HashSet<>(current); - removed.removeAll(newFactories); - - Set added = new HashSet<>(newFactories); - added.removeAll(current); - - boolean changed = false; - Iterator it0 = activeSockets.iterator(); - while (it0.hasNext()) { - WeightedSocket socket = it0.next(); - if (removed.contains(socket.getFactory())) { - it0.remove(); - try { - changed = true; - socket.close(); - } catch (Exception e) { - logger.warn("Exception while closing a RSocket", e); - } - } - } - Iterator it1 = activeFactories.iterator(); - while (it1.hasNext()) { - RSocketSupplier factory = it1.next(); - if (removed.contains(factory)) { - it1.remove(); - changed = true; - } - } - - activeFactories.addAll(added); - - if (changed && logger.isDebugEnabled()) { - StringBuilder msgBuilder = new StringBuilder(); - msgBuilder - .append("\nUpdated active factories (size: ") - .append(activeFactories.size()) - .append(")\n"); - for (RSocketSupplier f : activeFactories) { - msgBuilder.append(" + ").append(f).append('\n'); - } - msgBuilder.append("Active sockets:\n"); - for (WeightedSocket socket : activeSockets) { - msgBuilder.append(" + ").append(socket).append('\n'); - } - logger.debug(msgBuilder.toString()); - } - } - refreshSockets(); - } - - @Override - public void onError(Throwable t) { - // TODO: retry - logger.error("Error refreshing RSocket factories. They would no longer be refreshed.", t); - } - - @Override - public void onComplete() { - // TODO: retry - logger.warn("RSocket factories source completed. They would no longer be refreshed."); - } - - void close() { - subscription.cancel(); - } - } - - private class SocketAdder implements Subscriber { - private final RSocketSupplier factory; - - private int errors; - - private SocketAdder(RSocketSupplier factory) { - this.factory = factory; - } - - @Override - public void onSubscribe(Subscription s) { - s.request(1L); - } - - @Override - public void onNext(RSocket rs) { - synchronized (LoadBalancedRSocketMono.this) { - if (activeSockets.size() >= targetAperture) { - quickSlowestRS(); - } - - WeightedSocket weightedSocket = - new WeightedSocket(rs, factory, lowerQuantile, higherQuantile); - logger.debug("Adding new WeightedSocket {}", weightedSocket); - - activeSockets.add(weightedSocket); - started.onComplete(); - pendingSockets -= 1; - } - } - - @Override - public void onError(Throwable t) { - logger.warn("Exception while subscribing to the RSocket source", t); - synchronized (LoadBalancedRSocketMono.this) { - pendingSockets -= 1; - if (++errors < 5) { - activeFactories.add(factory); - } else { - logger.warn( - "Exception count greater than 5, not re-adding factory {}", factory.toString()); - } - } - } - - @Override - public void onComplete() {} + @Override + public Mono onClose() { + return onClose; } - private static final FailingRSocket FAILING_REACTIVE_SOCKET = new FailingRSocket(); - /** * (Null Object Pattern) This failing RSocket never succeed, it is useful for simplifying the code * when dealing with edge cases. */ private static class FailingRSocket implements RSocket { - @SuppressWarnings("ThrowableInstanceNeverThrown") - private static final NoAvailableRSocketException NO_AVAILABLE_RS_EXCEPTION = - noStacktrace(new NoAvailableRSocketException()); - - private static final Mono errorVoid = Mono.error(NO_AVAILABLE_RS_EXCEPTION); - private static final Mono errorPayload = Mono.error(NO_AVAILABLE_RS_EXCEPTION); + private static final Mono errorVoid = Mono.error(NoAvailableRSocketException.INSTANCE); + private static final Mono errorPayload = + Mono.error(NoAvailableRSocketException.INSTANCE); @Override public Mono fireAndForget(Payload payload) { @@ -712,8 +524,11 @@ public double availability() { } @Override - public Mono close() { - return Mono.empty(); + public void dispose() {} + + @Override + public boolean isDisposed() { + return true; } @Override @@ -726,15 +541,13 @@ public Mono onClose() { * Wrapper of a RSocket, it computes statistics about the req/resp calls and update availability * accordingly. */ - private class WeightedSocket extends RSocketProxy implements LoadBalancerSocketMetrics { + private class WeightedSocket implements LoadBalancerSocketMetrics, RSocket { private static final double STARTUP_PENALTY = Long.MAX_VALUE >> 12; - - private RSocketSupplier factory; private final Quantile lowerQuantile; private final Quantile higherQuantile; private final long inactivityFactor; - + private final MonoProcessor rSocketMono; private volatile int pending; // instantaneous rate private long stamp; // last timestamp we sent a request private long stamp0; // last timestamp we sent a request or receive a response @@ -745,14 +558,15 @@ private class WeightedSocket extends RSocketProxy implements LoadBalancerSocketM private AtomicLong pendingStreams; // number of active streams + private volatile double availability = 0.0; + private final MonoProcessor onClose = MonoProcessor.create(); + WeightedSocket( - RSocket child, RSocketSupplier factory, Quantile lowerQuantile, Quantile higherQuantile, int inactivityFactor) { - super(child); - this.factory = factory; + this.rSocketMono = MonoProcessor.create(); this.lowerQuantile = lowerQuantile; this.higherQuantile = higherQuantile; this.inactivityFactor = inactivityFactor; @@ -764,53 +578,161 @@ private class WeightedSocket extends RSocketProxy implements LoadBalancerSocketM this.median = new Median(); this.interArrivalTime = new Ewma(1, TimeUnit.MINUTES, DEFAULT_INITIAL_INTER_ARRIVAL_TIME); this.pendingStreams = new AtomicLong(); - child.onClose().doFinally(signalType -> removeSocket(this, true)).subscribe(); + + logger.debug("Creating WeightedSocket {} from factory {}", WeightedSocket.this, factory); + + WeightedSocket.this + .onClose() + .doFinally( + s -> { + pool.accept(factory); + activeSockets.remove(WeightedSocket.this); + logger.debug( + "Removed {} from factory {} from activeSockets", WeightedSocket.this, factory); + }) + .subscribe(); + + factory + .get() + .retryWhen( + Retry.backoff(weightedSocketRetries, weightedSocketBackOff) + .maxBackoff(weightedSocketMaxBackOff)) + .doOnError( + throwable -> { + logger.error( + "error while connecting {} from factory {}", + WeightedSocket.this, + factory, + throwable); + WeightedSocket.this.dispose(); + }) + .subscribe( + rSocket -> { + // When RSocket is closed, close the WeightedSocket + rSocket + .onClose() + .doFinally( + signalType -> { + logger.info( + "RSocket {} from factory {} closed", WeightedSocket.this, factory); + WeightedSocket.this.dispose(); + }) + .subscribe(); + + // When the factory is closed, close the RSocket + factory + .onClose() + .doFinally( + signalType -> { + logger.info("Factory {} closed", factory); + rSocket.dispose(); + }) + .subscribe(); + + // When the WeightedSocket is closed, close the RSocket + WeightedSocket.this + .onClose() + .doFinally( + signalType -> { + logger.info( + "WeightedSocket {} from factory {} closed", + WeightedSocket.this, + factory); + rSocket.dispose(); + }) + .subscribe(); + + /*synchronized (LoadBalancedRSocketMono.this) { + if (activeSockets.size() >= targetAperture) { + quickSlowestRS(); + pendingSockets -= 1; + } + }*/ + rSocketMono.onNext(rSocket); + availability = 1.0; + if (!WeightedSocket.this + .isDisposed()) { // May be already disposed because of retryBackoff delay + activeSockets.add(WeightedSocket.this); + logger.debug( + "Added WeightedSocket {} from factory {} to activeSockets", + WeightedSocket.this, + factory); + } + }); } - WeightedSocket( - RSocket child, RSocketSupplier factory, Quantile lowerQuantile, Quantile higherQuantile) { - this(child, factory, lowerQuantile, higherQuantile, DEFAULT_INTER_ARRIVAL_FACTOR); + WeightedSocket(RSocketSupplier factory, Quantile lowerQuantile, Quantile higherQuantile) { + this(factory, lowerQuantile, higherQuantile, DEFAULT_INTER_ARRIVAL_FACTOR); } @Override public Mono requestResponse(Payload payload) { - return Mono.from( - subscriber -> - source.requestResponse(payload).subscribe(new LatencySubscriber<>(subscriber, this))); + return rSocketMono.flatMap( + source -> + Mono.from( + subscriber -> + source + .requestResponse(payload) + .subscribe( + new LatencySubscriber<>( + Operators.toCoreSubscriber(subscriber), this)))); } @Override public Flux requestStream(Payload payload) { - return Flux.from( - subscriber -> - source.requestStream(payload).subscribe(new CountingSubscriber<>(subscriber, this))); + + return rSocketMono.flatMapMany( + source -> + Flux.from( + subscriber -> + source + .requestStream(payload) + .subscribe( + new CountingSubscriber<>( + Operators.toCoreSubscriber(subscriber), this)))); } @Override public Mono fireAndForget(Payload payload) { - return Mono.from( - subscriber -> - source.fireAndForget(payload).subscribe(new CountingSubscriber<>(subscriber, this))); + + return rSocketMono.flatMap( + source -> { + return Mono.from( + subscriber -> + source + .fireAndForget(payload) + .subscribe( + new CountingSubscriber<>( + Operators.toCoreSubscriber(subscriber), this))); + }); } @Override public Mono metadataPush(Payload payload) { - return Mono.from( - subscriber -> - source.metadataPush(payload).subscribe(new CountingSubscriber<>(subscriber, this))); + return rSocketMono.flatMap( + source -> { + return Mono.from( + subscriber -> + source + .metadataPush(payload) + .subscribe( + new CountingSubscriber<>( + Operators.toCoreSubscriber(subscriber), this))); + }); } @Override public Flux requestChannel(Publisher payloads) { - return Flux.from( - subscriber -> - source - .requestChannel(payloads) - .subscribe(new CountingSubscriber<>(subscriber, this))); - } - RSocketSupplier getFactory() { - return factory; + return rSocketMono.flatMapMany( + source -> + Flux.from( + subscriber -> + source + .requestChannel(payloads) + .subscribe( + new CountingSubscriber<>( + Operators.toCoreSubscriber(subscriber), this)))); } synchronized double getPredictedLatency() { @@ -880,8 +802,23 @@ private synchronized void observe(double rtt) { } @Override - public Mono close() { - return source.close(); + public double availability() { + return availability; + } + + @Override + public void dispose() { + onClose.onComplete(); + } + + @Override + public boolean isDisposed() { + return onClose.isDisposed(); + } + + @Override + public Mono onClose() { + return onClose; } @Override @@ -901,8 +838,7 @@ public String toString() { + pending + " availability= " + availability() - + ")->" - + source; + + ")->"; } @Override @@ -939,18 +875,23 @@ public long lastTimeUsedMillis() { * Subscriber wrapper used for request/response interaction model, measure and collect latency * information. */ - private class LatencySubscriber implements Subscriber { - private final Subscriber child; + private class LatencySubscriber implements CoreSubscriber { + private final CoreSubscriber child; private final WeightedSocket socket; private final AtomicBoolean done; private long start; - LatencySubscriber(Subscriber child, WeightedSocket socket) { + LatencySubscriber(CoreSubscriber child, WeightedSocket socket) { this.child = child; this.socket = socket; this.done = new AtomicBoolean(false); } + @Override + public Context currentContext() { + return child.currentContext(); + } + @Override public void onSubscribe(Subscription s) { start = incr(); @@ -982,7 +923,7 @@ public void onError(Throwable t) { child.onError(t); long now = decr(start); if (t instanceof TransportException || t instanceof ClosedChannelException) { - removeSocket(socket, true); + socket.dispose(); } else if (t instanceof TimeoutException) { observe(now - start); } @@ -1003,15 +944,20 @@ public void onComplete() { * Subscriber wrapper used for stream like interaction model, it only counts the number of * active streams */ - private class CountingSubscriber implements Subscriber { - private final Subscriber child; + private class CountingSubscriber implements CoreSubscriber { + private final CoreSubscriber child; private final WeightedSocket socket; - CountingSubscriber(Subscriber child, WeightedSocket socket) { + CountingSubscriber(CoreSubscriber child, WeightedSocket socket) { this.child = child; this.socket = socket; } + @Override + public Context currentContext() { + return child.currentContext(); + } + @Override public void onSubscribe(Subscription s) { socket.pendingStreams.incrementAndGet(); @@ -1028,7 +974,8 @@ public void onError(Throwable t) { socket.pendingStreams.decrementAndGet(); child.onError(t); if (t instanceof TransportException || t instanceof ClosedChannelException) { - removeSocket(socket, true); + logger.debug("Disposing {} from activeSockets because of error {}", socket, t); + socket.dispose(); } } diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancerSocketMetrics.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancerSocketMetrics.java index a4ec8571e..0cb35d180 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancerSocketMetrics.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancerSocketMetrics.java @@ -1,20 +1,24 @@ /* - * Copyright 2017 Netflix, Inc. - *

- * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - *

- * http://www.apache.org/licenses/LICENSE-2.0 - *

- * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on - * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package io.rsocket.client; import io.rsocket.Availability; +@Deprecated /** A contract for the metrics managed by {@link LoadBalancedRSocketMono} per socket. */ public interface LoadBalancerSocketMetrics extends Availability { diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/NoAvailableRSocketException.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/NoAvailableRSocketException.java new file mode 100644 index 000000000..295d25d75 --- /dev/null +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/NoAvailableRSocketException.java @@ -0,0 +1,41 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.client; + +@Deprecated +/** An exception that indicates that no RSocket was available. */ +public final class NoAvailableRSocketException extends Exception { + + /** + * The single instance of this type. Note that it is initialized without any stack trace. + */ + public static final NoAvailableRSocketException INSTANCE; + + private static final long serialVersionUID = -2785312562743351184L; + + static { + NoAvailableRSocketException exception = new NoAvailableRSocketException(); + exception.setStackTrace( + new StackTraceElement[] { + new StackTraceElement(exception.getClass().getName(), "", null, -1) + }); + + INSTANCE = exception; + } + + private NoAvailableRSocketException() {}; +} diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/RSocketSupplierPool.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/RSocketSupplierPool.java new file mode 100644 index 000000000..8249083ad --- /dev/null +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/RSocketSupplierPool.java @@ -0,0 +1,197 @@ +package io.rsocket.client; + +import io.rsocket.Closeable; +import io.rsocket.client.filter.RSocketSupplier; +import java.time.Duration; +import java.util.*; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.Consumer; +import java.util.function.Supplier; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; + +@Deprecated +public class RSocketSupplierPool + implements Supplier>, Consumer, Closeable { + private static final Logger logger = LoggerFactory.getLogger(RSocketSupplierPool.class); + private static final int EFFORT = 5; + + private final ArrayList factoryPool; + private final ArrayList leasedSuppliers; + + private final MonoProcessor onClose; + + public RSocketSupplierPool(Publisher> publisher) { + this.onClose = MonoProcessor.create(); + this.factoryPool = new ArrayList<>(); + this.leasedSuppliers = new ArrayList<>(); + + Disposable disposable = + Flux.from(publisher) + .doOnNext(this::handleNewFactories) + .onErrorResume( + t -> { + logger.error("error streaming RSocketSuppliers", t); + return Mono.delay(Duration.ofSeconds(10)).then(Mono.error(t)); + }) + .subscribe(); + + onClose.doFinally(s -> disposable.dispose()).subscribe(); + } + + private synchronized void handleNewFactories(Collection newFactories) { + Set current = new HashSet<>(factoryPool.size() + leasedSuppliers.size()); + current.addAll(factoryPool); + current.addAll(leasedSuppliers); + + Set removed = new HashSet<>(current); + removed.removeAll(newFactories); + + Set added = new HashSet<>(newFactories); + added.removeAll(current); + + boolean changed = false; + Iterator it0 = leasedSuppliers.iterator(); + while (it0.hasNext()) { + RSocketSupplier supplier = it0.next(); + if (removed.contains(supplier)) { + it0.remove(); + try { + changed = true; + supplier.dispose(); + } catch (Exception e) { + logger.warn("Exception while closing a RSocket", e); + } + } + } + + Iterator it1 = factoryPool.iterator(); + while (it1.hasNext()) { + RSocketSupplier supplier = it1.next(); + if (removed.contains(supplier)) { + it1.remove(); + try { + changed = true; + supplier.dispose(); + } catch (Exception e) { + logger.warn("Exception while closing a RSocket", e); + } + } + } + + factoryPool.addAll(added); + if (!added.isEmpty()) { + changed = true; + } + + if (changed && logger.isDebugEnabled()) { + StringBuilder msgBuilder = new StringBuilder(); + msgBuilder + .append("\nUpdated active factories (size: ") + .append(factoryPool.size()) + .append(")\n"); + for (RSocketSupplier f : factoryPool) { + msgBuilder.append(" + ").append(f).append('\n'); + } + msgBuilder.append("Active sockets:\n"); + for (RSocketSupplier socket : leasedSuppliers) { + msgBuilder.append(" + ").append(socket).append('\n'); + } + logger.debug(msgBuilder.toString()); + } + } + + @Override + public synchronized void accept(RSocketSupplier rSocketSupplier) { + boolean contained = leasedSuppliers.remove(rSocketSupplier); + if (contained + && !rSocketSupplier + .isDisposed()) { // only added leasedSupplier back to factoryPool if it's still there + factoryPool.add(rSocketSupplier); + } + } + + @Override + public synchronized Optional get() { + Optional optional = Optional.empty(); + int poolSize = factoryPool.size(); + if (poolSize == 1) { + RSocketSupplier rSocketSupplier = factoryPool.get(0); + if (rSocketSupplier.availability() > 0.0) { + factoryPool.remove(0); + leasedSuppliers.add(rSocketSupplier); + logger.debug("Added {} to leasedSuppliers", rSocketSupplier); + optional = Optional.of(rSocketSupplier); + } + } else if (poolSize > 1) { + Random rng = ThreadLocalRandom.current(); + int size = factoryPool.size(); + RSocketSupplier factory0 = null; + RSocketSupplier factory1 = null; + int i0 = 0; + int i1 = 0; + for (int i = 0; i < EFFORT; i++) { + i0 = rng.nextInt(size); + i1 = rng.nextInt(size - 1); + if (i1 >= i0) { + i1++; + } + factory0 = factoryPool.get(i0); + factory1 = factoryPool.get(i1); + if (factory0.availability() > 0.0 && factory1.availability() > 0.0) { + break; + } + } + if (factory0.availability() > factory1.availability()) { + factoryPool.remove(i0); + leasedSuppliers.add(factory0); + logger.debug("Added {} to leasedSuppliers", factory0); + optional = Optional.of(factory0); + } else { + factoryPool.remove(i1); + leasedSuppliers.add(factory1); + logger.debug("Added {} to leasedSuppliers", factory1); + optional = Optional.of(factory1); + } + } + + return optional; + } + + @Override + public Mono onClose() { + return onClose; + } + + @Override + public void dispose() { + if (!onClose.isDisposed()) { + onClose.onComplete(); + + close(factoryPool); + close(leasedSuppliers); + } + } + + private void close(Collection suppliers) { + for (RSocketSupplier supplier : suppliers) { + try { + supplier.dispose(); + } catch (Throwable t) { + } + } + } + + public synchronized int poolSize() { + return factoryPool.size(); + } + + public synchronized boolean isPoolEmpty() { + return factoryPool.isEmpty(); + } +} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/NotConnectedException.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/TimeoutException.java similarity index 58% rename from rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/NotConnectedException.java rename to rsocket-load-balancer/src/main/java/io/rsocket/client/TimeoutException.java index f86534daf..a32ac2224 100644 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/NotConnectedException.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/TimeoutException.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,17 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.aeron.internal; -public class NotConnectedException extends RuntimeException { +package io.rsocket.client; - private static final long serialVersionUID = -5521573868855763403L; +@Deprecated +public final class TimeoutException extends Exception { - public NotConnectedException() { - super(); - } + public static final TimeoutException INSTANCE = new TimeoutException(); - public NotConnectedException(String message) { - super(message); - } + private static final long serialVersionUID = -3094321310317812063L; + + private TimeoutException() {} } diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/TransportException.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/TransportException.java similarity index 68% rename from rsocket-core/src/main/java/io/rsocket/exceptions/TransportException.java rename to rsocket-load-balancer/src/main/java/io/rsocket/client/TransportException.java index af18dd35e..4779c6d4d 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/TransportException.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/TransportException.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.exceptions; -public class TransportException extends Throwable { +package io.rsocket.client; - private static final long serialVersionUID = 7541914004190564240L; +@Deprecated +public final class TransportException extends Throwable { + + private static final long serialVersionUID = -3339846338318701123L; public TransportException(Throwable t) { super(t); diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/BackupRequestSocket.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/BackupRequestSocket.java index 745ded482..beb424797 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/BackupRequestSocket.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/BackupRequestSocket.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package io.rsocket.client.filter; import io.rsocket.Payload; @@ -32,6 +33,7 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +@Deprecated public class BackupRequestSocket implements RSocket { private final ScheduledExecutorService executor; private final RSocket child; @@ -88,8 +90,13 @@ public double availability() { } @Override - public Mono close() { - return child.close(); + public void dispose() { + child.dispose(); + } + + @Override + public boolean isDisposed() { + return child.isDisposed(); } @Override diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/RSocketSupplier.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/RSocketSupplier.java index 29617188b..aaf9f71e6 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/RSocketSupplier.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/RSocketSupplier.java @@ -1,17 +1,17 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package io.rsocket.client.filter; @@ -31,6 +31,7 @@ import reactor.core.publisher.MonoProcessor; /** */ +@Deprecated public class RSocketSupplier implements Availability, Supplier>, Closeable { private static final double EPSILON = 1e-4; @@ -87,8 +88,13 @@ public Mono get() { } @Override - public Mono close() { - return Mono.empty().doFinally(s -> onClose.onComplete()).then(); + public void dispose() { + onClose.onComplete(); + } + + @Override + public boolean isDisposed() { + return onClose.isDisposed(); } @Override @@ -100,7 +106,7 @@ private class AvailabilityAwareRSocketProxy extends RSocketProxy { public AvailabilityAwareRSocketProxy(RSocket source) { super(source); - onClose.then(close()).subscribe(); + onClose.doFinally(signalType -> source.dispose()).subscribe(); } @Override diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/RSockets.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/RSockets.java index 76041bca1..89ff74143 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/RSockets.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/RSockets.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -27,6 +27,7 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +@Deprecated public final class RSockets { private RSockets() { @@ -72,7 +73,7 @@ public Mono metadataPush(Payload payload) { /** * Provides a mapping function to wrap a {@code RSocket} such that a call to {@link - * RSocket#close()} does not cancel all pending requests. Instead, it will wait for all pending + * RSocket#dispose()} does not cancel all pending requests. Instead, it will wait for all pending * requests to finish and then close the socket. * * @return Function to transform any socket into a safe closing socket. @@ -91,7 +92,7 @@ public Mono fireAndForget(Payload payload) { .doFinally( signalType -> { if (count.decrementAndGet() == 0 && closed.get()) { - source.close().subscribe(); + source.dispose(); } }); } @@ -104,7 +105,7 @@ public Mono requestResponse(Payload payload) { .doFinally( signalType -> { if (count.decrementAndGet() == 0 && closed.get()) { - source.close().subscribe(); + source.dispose(); } }); } @@ -117,7 +118,7 @@ public Flux requestStream(Payload payload) { .doFinally( signalType -> { if (count.decrementAndGet() == 0 && closed.get()) { - source.close().subscribe(); + source.dispose(); } }); } @@ -130,7 +131,7 @@ public Flux requestChannel(Publisher payloads) { .doFinally( signalType -> { if (count.decrementAndGet() == 0 && closed.get()) { - source.close().subscribe(); + source.dispose(); } }); } @@ -143,24 +144,18 @@ public Mono metadataPush(Payload payload) { .doFinally( signalType -> { if (count.decrementAndGet() == 0 && closed.get()) { - source.close().subscribe(); + source.dispose(); } }); } @Override - public Mono close() { - return Mono.defer( - () -> { - if (closed.compareAndSet(false, true)) { - if (count.get() == 0) { - return source.close(); - } else { - return source.onClose(); - } - } - return source.onClose(); - }); + public void dispose() { + if (closed.compareAndSet(false, true)) { + if (count.get() == 0) { + source.dispose(); + } + } } }; } diff --git a/rsocket-transport-aeron/build.gradle b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/package-info.java similarity index 70% rename from rsocket-transport-aeron/build.gradle rename to rsocket-load-balancer/src/main/java/io/rsocket/client/filter/package-info.java index f20e7cb44..55ce5646c 100644 --- a/rsocket-transport-aeron/build.gradle +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -14,9 +14,7 @@ * limitations under the License. */ -dependencies { - compile project(':rsocket-core') - compile "io.aeron:aeron-all:1.4.1" +@NonNullApi +package io.rsocket.client.filter; - testCompile project(':rsocket-test') -} +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/TimeoutException.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/package-info.java similarity index 69% rename from rsocket-core/src/main/java/io/rsocket/exceptions/TimeoutException.java rename to rsocket-load-balancer/src/main/java/io/rsocket/client/package-info.java index 8198ee743..ec21dee96 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/TimeoutException.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -14,9 +14,7 @@ * limitations under the License. */ -package io.rsocket.exceptions; +@NonNullApi +package io.rsocket.client; -public class TimeoutException extends Exception { - - private static final long serialVersionUID = -6352901497935205059L; -} +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/stat/Ewma.java b/rsocket-load-balancer/src/main/java/io/rsocket/stat/Ewma.java index 06ab60276..3968ec0a4 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/stat/Ewma.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/stat/Ewma.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package io.rsocket.stat; import io.rsocket.util.Clock; @@ -26,6 +27,7 @@ *

e.g. with a half-life of 10 unit, if you insert 100 at t=0 and 200 at t=10 the ewma will be * equal to (200 - 100)/2 = 150 (half of the distance between the new and the old value) */ +@Deprecated public class Ewma { private final long tau; private volatile long stamp; diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/stat/FrugalQuantile.java b/rsocket-load-balancer/src/main/java/io/rsocket/stat/FrugalQuantile.java index 3c83f422f..99c12e801 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/stat/FrugalQuantile.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/stat/FrugalQuantile.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package io.rsocket.stat; import java.util.Random; @@ -24,6 +25,7 @@ * *

More info: http://blog.aggregateknowledge.com/2013/09/16/sketch-of-the-day-frugal-streaming/ */ +@Deprecated public class FrugalQuantile implements Quantile { private final double increment; private double quantile; diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/stat/Median.java b/rsocket-load-balancer/src/main/java/io/rsocket/stat/Median.java index 8256af70f..00dd69de9 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/stat/Median.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/stat/Median.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package io.rsocket.stat; /** This implementation gives better results because it considers more data-point. */ +@Deprecated public class Median extends FrugalQuantile { public Median() { super(0.5, 1.0, null); diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/stat/Quantile.java b/rsocket-load-balancer/src/main/java/io/rsocket/stat/Quantile.java index f2a341b26..aa3667e8f 100644 --- a/rsocket-load-balancer/src/main/java/io/rsocket/stat/Quantile.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/stat/Quantile.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -15,6 +15,7 @@ */ package io.rsocket.stat; +@Deprecated public interface Quantile { /** @return the estimation of the current value of the quantile */ double estimation(); diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/NoAvailableRSocketException.java b/rsocket-load-balancer/src/main/java/io/rsocket/stat/package-info.java similarity index 68% rename from rsocket-core/src/main/java/io/rsocket/exceptions/NoAvailableRSocketException.java rename to rsocket-load-balancer/src/main/java/io/rsocket/stat/package-info.java index 40a2de99a..cfb071175 100644 --- a/rsocket-core/src/main/java/io/rsocket/exceptions/NoAvailableRSocketException.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/stat/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.exceptions; -public class NoAvailableRSocketException extends Exception { +@NonNullApi +package io.rsocket.stat; - private static final long serialVersionUID = 7608370678694273507L; -} +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java b/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java index d1a98775d..52bf89558 100644 --- a/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java +++ b/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,45 +16,41 @@ package io.rsocket.client; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.fail; + import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.client.filter.RSocketSupplier; -import io.rsocket.util.PayloadImpl; -import java.net.InetSocketAddress; -import java.net.SocketAddress; import java.util.Arrays; +import java.util.Collections; import java.util.List; -import java.util.concurrent.CountDownLatch; +import java.util.concurrent.CompletableFuture; import java.util.function.Function; -import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.mockito.Mockito; import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; public class LoadBalancedRSocketMonoTest { - @Test(timeout = 10_000L) + @Test + @Timeout(10_000L) public void testNeverSelectFailingFactories() throws InterruptedException { - InetSocketAddress local0 = InetSocketAddress.createUnresolved("localhost", 7000); - InetSocketAddress local1 = InetSocketAddress.createUnresolved("localhost", 7001); - TestingRSocket socket = new TestingRSocket(Function.identity()); - RSocketSupplier failing = failingClient(local0); + RSocketSupplier failing = failingClient(); RSocketSupplier succeeding = succeedingFactory(socket); List factories = Arrays.asList(failing, succeeding); testBalancer(factories); } - @Test(timeout = 10_000L) + @Test + @Timeout(10_000L) public void testNeverSelectFailingSocket() throws InterruptedException { - InetSocketAddress local0 = InetSocketAddress.createUnresolved("localhost", 7000); - InetSocketAddress local1 = InetSocketAddress.createUnresolved("localhost", 7001); - TestingRSocket socket = new TestingRSocket(Function.identity()); TestingRSocket failingSocket = new TestingRSocket(Function.identity()) { @@ -76,6 +72,35 @@ public double availability() { testBalancer(clients); } + @Test + @Timeout(10_000L) + @Disabled + public void testRefreshesSocketsOnSelectBeforeReturningFailedAfterNewFactoriesDelivered() { + TestingRSocket socket = new TestingRSocket(Function.identity()); + + CompletableFuture laterSupplier = new CompletableFuture<>(); + Flux> factories = + Flux.create( + s -> { + s.next(Collections.emptyList()); + + laterSupplier.handle( + (RSocketSupplier result, Throwable t) -> { + s.next(Collections.singletonList(result)); + return null; + }); + }); + + LoadBalancedRSocketMono balancer = LoadBalancedRSocketMono.create(factories); + + assertThat(balancer.availability()).isZero(); + + laterSupplier.complete(succeedingFactory(socket)); + balancer.rSocketMono.block(); + + assertThat(balancer.availability()).isEqualTo(1.0); + } + private void testBalancer(List factories) throws InterruptedException { Publisher> src = s -> { @@ -92,56 +117,24 @@ private void testBalancer(List factories) throws InterruptedExc Flux.range(0, 100).flatMap(i -> balancer).blockLast(); } - private void makeAcall(RSocket balancer) throws InterruptedException { - CountDownLatch latch = new CountDownLatch(1); - - balancer - .requestResponse(PayloadImpl.EMPTY) - .subscribe( - new Subscriber() { - @Override - public void onSubscribe(Subscription s) { - s.request(1L); - } - - @Override - public void onNext(Payload payload) { - System.out.println("Successfully receiving a response"); - } - - @Override - public void onError(Throwable t) { - t.printStackTrace(); - Assert.assertTrue(false); - latch.countDown(); - } - - @Override - public void onComplete() { - latch.countDown(); - } - }); - - latch.await(); - } - private static RSocketSupplier succeedingFactory(RSocket socket) { RSocketSupplier mock = Mockito.mock(RSocketSupplier.class); Mockito.when(mock.availability()).thenReturn(1.0); Mockito.when(mock.get()).thenReturn(Mono.just(socket)); + Mockito.when(mock.onClose()).thenReturn(Mono.never()); return mock; } - private static RSocketSupplier failingClient(SocketAddress sa) { + private static RSocketSupplier failingClient() { RSocketSupplier mock = Mockito.mock(RSocketSupplier.class); Mockito.when(mock.availability()).thenReturn(0.0); Mockito.when(mock.get()) .thenAnswer( a -> { - Assert.fail(); + fail(); return null; }); diff --git a/rsocket-load-balancer/src/test/java/io/rsocket/client/RSocketSupplierTest.java b/rsocket-load-balancer/src/test/java/io/rsocket/client/RSocketSupplierTest.java index 7fd439ed1..9e1982465 100644 --- a/rsocket-load-balancer/src/test/java/io/rsocket/client/RSocketSupplierTest.java +++ b/rsocket-load-balancer/src/test/java/io/rsocket/client/RSocketSupplierTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package io.rsocket.client; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.verify; @@ -25,12 +25,12 @@ import io.rsocket.RSocket; import io.rsocket.client.filter.RSocketSupplier; import io.rsocket.test.TestSubscriber; -import io.rsocket.util.PayloadImpl; +import io.rsocket.util.EmptyPayload; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiConsumer; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.mockito.Mockito; import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; @@ -43,8 +43,8 @@ public class RSocketSupplierTest { public void testError() throws InterruptedException { testRSocket( (latch, socket) -> { - assertEquals(1.0, socket.availability(), 0.0); - Publisher payloadPublisher = socket.requestResponse(PayloadImpl.EMPTY); + assertThat(socket.availability()).isEqualTo(1.0); + Publisher payloadPublisher = socket.requestResponse(EmptyPayload.INSTANCE); Subscriber subscriber = TestSubscriber.create(); payloadPublisher.subscribe(subscriber); @@ -63,7 +63,7 @@ public void testError() throws InterruptedException { payloadPublisher.subscribe(subscriber); verify(subscriber).onError(any(RuntimeException.class)); double bad = socket.availability(); - assertTrue(good > bad); + assertThat(good > bad).isTrue(); latch.countDown(); }); } @@ -72,8 +72,8 @@ public void testError() throws InterruptedException { public void testWidowReset() throws InterruptedException { testRSocket( (latch, socket) -> { - assertEquals(1.0, socket.availability(), 0.0); - Publisher payloadPublisher = socket.requestResponse(PayloadImpl.EMPTY); + assertThat(socket.availability()).isEqualTo(1.0); + Publisher payloadPublisher = socket.requestResponse(EmptyPayload.INSTANCE); Subscriber subscriber = TestSubscriber.create(); payloadPublisher.subscribe(subscriber); @@ -86,7 +86,7 @@ public void testWidowReset() throws InterruptedException { verify(subscriber).onError(any(RuntimeException.class)); double bad = socket.availability(); - assertTrue(good > bad); + assertThat(good > bad).isTrue(); try { Thread.sleep(200); @@ -95,7 +95,7 @@ public void testWidowReset() throws InterruptedException { } double reset = socket.availability(); - assertTrue(reset > bad); + assertThat(reset > bad).isTrue(); latch.countDown(); }); } @@ -106,7 +106,7 @@ private void testRSocket(BiConsumer f) throws Interrupt new TestingRSocket( input -> { if (count.getAndIncrement() < 1) { - return PayloadImpl.EMPTY; + return EmptyPayload.INSTANCE; } else { throw new RuntimeException(); } diff --git a/rsocket-load-balancer/src/test/java/io/rsocket/client/TestingRSocket.java b/rsocket-load-balancer/src/test/java/io/rsocket/client/TestingRSocket.java index a4ca079d5..2827c8ed4 100644 --- a/rsocket-load-balancer/src/test/java/io/rsocket/client/TestingRSocket.java +++ b/rsocket-load-balancer/src/test/java/io/rsocket/client/TestingRSocket.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -24,12 +24,13 @@ import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; +import reactor.core.Scannable; import reactor.core.publisher.*; public class TestingRSocket implements RSocket { private final AtomicInteger count; - private final MonoProcessor closeSubject = MonoProcessor.create(); + private final Sinks.Empty onClose = Sinks.empty(); private final BiFunction, Payload, Boolean> eachPayloadHandler; public TestingRSocket(Function responder) { @@ -127,16 +128,18 @@ public double availability() { } @Override - public Mono close() { - return Mono.defer( - () -> { - closeSubject.onComplete(); - return closeSubject; - }); + public void dispose() { + onClose.tryEmitEmpty(); + } + + @Override + @SuppressWarnings("ConstantConditions") + public boolean isDisposed() { + return onClose.scan(Scannable.Attr.TERMINATED) || onClose.scan(Scannable.Attr.CANCELLED); } @Override public Mono onClose() { - return closeSubject; + return onClose.asMono(); } } diff --git a/rsocket-load-balancer/src/test/java/io/rsocket/client/TimeoutClientTest.java b/rsocket-load-balancer/src/test/java/io/rsocket/client/TimeoutClientTest.java index 0210341d9..b8866b1f6 100644 --- a/rsocket-load-balancer/src/test/java/io/rsocket/client/TimeoutClientTest.java +++ b/rsocket-load-balancer/src/test/java/io/rsocket/client/TimeoutClientTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,16 +16,14 @@ package io.rsocket.client; -import static org.hamcrest.Matchers.instanceOf; +import static org.assertj.core.api.Assertions.assertThat; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.client.filter.RSockets; -import io.rsocket.exceptions.TimeoutException; -import io.rsocket.util.PayloadImpl; +import io.rsocket.util.EmptyPayload; import java.time.Duration; -import org.hamcrest.MatcherAssert; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; @@ -36,7 +34,7 @@ public void testTimeoutSocket() { RSocket timeout = RSockets.timeout(Duration.ofMillis(50)).apply(socket); timeout - .requestResponse(PayloadImpl.EMPTY) + .requestResponse(EmptyPayload.INSTANCE) .subscribe( new Subscriber() { @Override @@ -51,8 +49,9 @@ public void onNext(Payload payload) { @Override public void onError(Throwable t) { - MatcherAssert.assertThat( - "Unexpected exception in onError", t, instanceOf(TimeoutException.class)); + assertThat(t) + .describedAs("Unexpected exception in onError") + .isInstanceOf(TimeoutException.class); } @Override diff --git a/rsocket-load-balancer/src/test/java/io/rsocket/stat/MedianTest.java b/rsocket-load-balancer/src/test/java/io/rsocket/stat/MedianTest.java index 1dec437e0..b214a725e 100644 --- a/rsocket-load-balancer/src/test/java/io/rsocket/stat/MedianTest.java +++ b/rsocket-load-balancer/src/test/java/io/rsocket/stat/MedianTest.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -18,8 +18,8 @@ import java.util.Arrays; import java.util.Random; -import org.junit.Assert; -import org.junit.Test; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; public class MedianTest { private double errorSum = 0; @@ -59,7 +59,8 @@ private void testMedian(Random rng) { maxError = Math.max(maxError, error); minError = Math.min(minError, error); - Assert.assertTrue( - "p50=" + estimation + ", real=" + expected + ", error=" + error, error < 0.02); + Assertions.assertThat(error < 0.02) + .describedAs("p50=" + estimation + ", real=" + expected + ", error=" + error) + .isTrue(); } } diff --git a/rsocket-load-balancer/src/test/resources/log4j.properties b/rsocket-load-balancer/src/test/resources/log4j.properties deleted file mode 100644 index 6477d125f..000000000 --- a/rsocket-load-balancer/src/test/resources/log4j.properties +++ /dev/null @@ -1,33 +0,0 @@ -# -# Copyright 2016 Netflix, Inc. -#

-# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -#

-# http://www.apache.org/licenses/LICENSE-2.0 -#

-# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on -# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. -# - - -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -log4j.rootLogger=INFO, stdout - -log4j.appender.stdout=org.apache.log4j.ConsoleAppender -log4j.appender.stdout.layout=org.apache.log4j.PatternLayout -log4j.appender.stdout.layout.ConversionPattern=%d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] (%F:%L) - %m%n \ No newline at end of file diff --git a/rsocket-load-balancer/src/test/resources/logback-test.xml b/rsocket-load-balancer/src/test/resources/logback-test.xml new file mode 100644 index 000000000..13e65b37d --- /dev/null +++ b/rsocket-load-balancer/src/test/resources/logback-test.xml @@ -0,0 +1,33 @@ + + + + + + + + %d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] %c{1} - %m%n + + + + + + + + + + + diff --git a/rsocket-micrometer/build.gradle b/rsocket-micrometer/build.gradle new file mode 100644 index 000000000..debf02f34 --- /dev/null +++ b/rsocket-micrometer/build.gradle @@ -0,0 +1,47 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +plugins { + id 'java-library' + id 'maven-publish' + id 'signing' +} + +dependencies { + api project(':rsocket-core') + api 'io.micrometer:micrometer-observation' + api 'io.micrometer:micrometer-core' + api 'io.micrometer:micrometer-tracing' + + implementation 'org.slf4j:slf4j-api' + + testImplementation project(':rsocket-test') + testImplementation 'io.projectreactor:reactor-test' + testImplementation 'org.assertj:assertj-core' + testImplementation 'org.junit.jupiter:junit-jupiter-api' + testImplementation 'org.mockito:mockito-core' + + testRuntimeOnly 'ch.qos.logback:logback-classic' + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' +} + +jar { + manifest { + attributes("Automatic-Module-Name": "rsocket.micrometer") + } +} + +description = 'Transparent Metrics exposure to Micrometer' diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java new file mode 100644 index 000000000..7c7ac37b9 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java @@ -0,0 +1,267 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer; + +import static io.rsocket.frame.FrameType.*; + +import io.micrometer.core.instrument.*; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.DuplexConnectionInterceptor.Type; +import java.net.SocketAddress; +import java.util.Objects; +import java.util.function.Consumer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * An implementation of {@link DuplexConnection} that intercepts frames and gathers Micrometer + * metrics about them. + * + *

The metric is called {@code rsocket.frame} and is tagged with {@code connection.type} ({@link + * Type}), {@code frame.type} ({@link FrameType}), and any additional configured tags. {@code + * rsocket.duplex.connection.close} and {@code rsocket.duplex.connection.dispose} metrics, tagged + * with {@code connection.type} ({@link Type}) and any additional configured tags are also + * collected. + * + * @see Micrometer + */ +final class MicrometerDuplexConnection implements DuplexConnection { + + private final Counter close; + + private final DuplexConnection delegate; + + private final Counter dispose; + + private final FrameCounters frameCounters; + + /** + * Creates a new {@link DuplexConnection}. + * + * @param connectionType the type of connection being monitored + * @param delegate the {@link DuplexConnection} to delegate to + * @param meterRegistry the {@link MeterRegistry} to use + * @param tags additional tags to attach to {@link Meter}s + * @throws NullPointerException if {@code connectionType}, {@code delegate}, or {@code + * meterRegistry} is {@code null} + */ + MicrometerDuplexConnection( + Type connectionType, DuplexConnection delegate, MeterRegistry meterRegistry, Tag... tags) { + + Objects.requireNonNull(connectionType, "connectionType must not be null"); + this.delegate = Objects.requireNonNull(delegate, "delegate must not be null"); + Objects.requireNonNull(meterRegistry, "meterRegistry must not be null"); + + this.close = + meterRegistry.counter( + "rsocket.duplex.connection.close", + Tags.of(tags).and("connection.type", connectionType.name())); + this.dispose = + meterRegistry.counter( + "rsocket.duplex.connection.dispose", + Tags.of(tags).and("connection.type", connectionType.name())); + this.frameCounters = new FrameCounters(connectionType, meterRegistry, tags); + } + + @Override + public ByteBufAllocator alloc() { + return delegate.alloc(); + } + + @Override + public SocketAddress remoteAddress() { + return delegate.remoteAddress(); + } + + @Override + public void dispose() { + delegate.dispose(); + dispose.increment(); + } + + @Override + public Mono onClose() { + return delegate.onClose().doAfterTerminate(close::increment); + } + + @Override + public Flux receive() { + return delegate.receive().doOnNext(frameCounters); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + frameCounters.accept(frame); + delegate.sendFrame(streamId, frame); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + delegate.sendErrorAndClose(e); + } + + private static final class FrameCounters implements Consumer { + + private final Logger logger = LoggerFactory.getLogger(this.getClass()); + + private final Counter cancel; + + private final Counter complete; + + private final Counter error; + + private final Counter extension; + + private final Counter keepalive; + + private final Counter lease; + + private final Counter metadataPush; + + private final Counter next; + + private final Counter nextComplete; + + private final Counter payload; + + private final Counter requestChannel; + + private final Counter requestFireAndForget; + + private final Counter requestN; + + private final Counter requestResponse; + + private final Counter requestStream; + + private final Counter resume; + + private final Counter resumeOk; + + private final Counter setup; + + private final Counter unknown; + + private FrameCounters(Type connectionType, MeterRegistry meterRegistry, Tag... tags) { + this.cancel = counter(connectionType, meterRegistry, CANCEL, tags); + this.complete = counter(connectionType, meterRegistry, COMPLETE, tags); + this.error = counter(connectionType, meterRegistry, ERROR, tags); + this.extension = counter(connectionType, meterRegistry, EXT, tags); + this.keepalive = counter(connectionType, meterRegistry, KEEPALIVE, tags); + this.lease = counter(connectionType, meterRegistry, LEASE, tags); + this.metadataPush = counter(connectionType, meterRegistry, METADATA_PUSH, tags); + this.next = counter(connectionType, meterRegistry, NEXT, tags); + this.nextComplete = counter(connectionType, meterRegistry, NEXT_COMPLETE, tags); + this.payload = counter(connectionType, meterRegistry, PAYLOAD, tags); + this.requestChannel = counter(connectionType, meterRegistry, REQUEST_CHANNEL, tags); + this.requestFireAndForget = counter(connectionType, meterRegistry, REQUEST_FNF, tags); + this.requestN = counter(connectionType, meterRegistry, REQUEST_N, tags); + this.requestResponse = counter(connectionType, meterRegistry, REQUEST_RESPONSE, tags); + this.requestStream = counter(connectionType, meterRegistry, REQUEST_STREAM, tags); + this.resume = counter(connectionType, meterRegistry, RESUME, tags); + this.resumeOk = counter(connectionType, meterRegistry, RESUME_OK, tags); + this.setup = counter(connectionType, meterRegistry, SETUP, tags); + this.unknown = counter(connectionType, meterRegistry, "UNKNOWN", tags); + } + + private static Counter counter( + Type connectionType, MeterRegistry meterRegistry, FrameType frameType, Tag... tags) { + + return counter(connectionType, meterRegistry, frameType.name(), tags); + } + + private static Counter counter( + Type connectionType, MeterRegistry meterRegistry, String frameType, Tag... tags) { + + return meterRegistry.counter( + "rsocket.frame", + Tags.of(tags).and("connection.type", connectionType.name()).and("frame.type", frameType)); + } + + @Override + public void accept(ByteBuf frame) { + FrameType frameType = FrameHeaderCodec.frameType(frame); + + switch (frameType) { + case SETUP: + this.setup.increment(); + break; + case LEASE: + this.lease.increment(); + break; + case KEEPALIVE: + this.keepalive.increment(); + break; + case REQUEST_RESPONSE: + this.requestResponse.increment(); + break; + case REQUEST_FNF: + this.requestFireAndForget.increment(); + break; + case REQUEST_STREAM: + this.requestStream.increment(); + break; + case REQUEST_CHANNEL: + this.requestChannel.increment(); + break; + case REQUEST_N: + this.requestN.increment(); + break; + case CANCEL: + this.cancel.increment(); + break; + case PAYLOAD: + this.payload.increment(); + break; + case ERROR: + this.error.increment(); + break; + case METADATA_PUSH: + this.metadataPush.increment(); + break; + case RESUME: + this.resume.increment(); + break; + case RESUME_OK: + this.resumeOk.increment(); + break; + case NEXT: + this.next.increment(); + break; + case COMPLETE: + this.complete.increment(); + break; + case NEXT_COMPLETE: + this.nextComplete.increment(); + break; + case EXT: + this.extension.increment(); + break; + default: + this.logger.debug("Skipping count of unknown frame type: {}", frameType); + this.unknown.increment(); + } + } + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnectionInterceptor.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnectionInterceptor.java new file mode 100644 index 000000000..b94e969ec --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnectionInterceptor.java @@ -0,0 +1,64 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer; + +import io.micrometer.core.instrument.Meter; +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.Tag; +import io.rsocket.DuplexConnection; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.DuplexConnectionInterceptor; +import java.util.Objects; + +/** + * An implementation of {@link DuplexConnectionInterceptor} that intercepts frames and gathers + * Micrometer metrics about them. + * + *

The metric is called {@code rsocket.frame} and is tagged with {@code connection.type} ({@link + * Type}), {@code frame.type} ({@link FrameType}), and any additional configured tags. {@code + * rsocket.duplex.connection.close} and {@code rsocket.duplex.connection.dispose} metrics, tagged + * with {@code connection.type} ({@link Type}) and any additional configured tags are also + * collected. + * + * @see Micrometer + */ +public final class MicrometerDuplexConnectionInterceptor implements DuplexConnectionInterceptor { + + private final MeterRegistry meterRegistry; + + private final Tag[] tags; + + /** + * Creates a new {@link DuplexConnectionInterceptor}. + * + * @param meterRegistry the {@link MeterRegistry} to use to create {@link Meter}s. + * @param tags the additional tags to attach to each {@link Meter} + * @throws NullPointerException if {@code meterRegistry} is {@code null} + */ + public MicrometerDuplexConnectionInterceptor(MeterRegistry meterRegistry, Tag... tags) { + this.meterRegistry = Objects.requireNonNull(meterRegistry, "meterRegistry must not be null"); + this.tags = tags; + } + + @Override + public MicrometerDuplexConnection apply(Type connectionType, DuplexConnection delegate) { + Objects.requireNonNull(connectionType, "connectionType must not be null"); + Objects.requireNonNull(delegate, "delegate must not be null"); + + return new MicrometerDuplexConnection(connectionType, delegate, meterRegistry, tags); + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerRSocket.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerRSocket.java new file mode 100644 index 000000000..9e1abbc03 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerRSocket.java @@ -0,0 +1,206 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer; + +import static reactor.core.publisher.SignalType.CANCEL; +import static reactor.core.publisher.SignalType.ON_COMPLETE; +import static reactor.core.publisher.SignalType.ON_ERROR; + +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Meter; +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.Tag; +import io.micrometer.core.instrument.Tags; +import io.micrometer.core.instrument.Timer; +import io.micrometer.core.instrument.Timer.Sample; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import java.util.Objects; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.SignalType; + +/** + * An implementation of {@link RSocket} that intercepts interactions and gathers Micrometer metrics + * about them. + * + *

The metrics are called {@code rsocket.[ metadata.push | request.channel | request.fnf | + * request.response | request.stream ]} and is tagged with {@code signal.type} ({@link SignalType}) + * and any additional configured tags. + * + * @see Micrometer + */ +final class MicrometerRSocket implements RSocket { + + private final RSocket delegate; + + private final InteractionCounters metadataPush; + + private final InteractionCounters requestChannel; + + private final InteractionCounters requestFireAndForget; + + private final InteractionTimers requestResponse; + + private final InteractionCounters requestStream; + + /** + * Creates a new {@link RSocket}. + * + * @param delegate the {@link RSocket} to delegate to + * @param meterRegistry the {@link MeterRegistry} to use + * @param tags additional tags to attach to {@link Meter}s + * @throws NullPointerException if {@code delegate} or {@code meterRegistry} is {@code null} + */ + MicrometerRSocket(RSocket delegate, MeterRegistry meterRegistry, Tag... tags) { + this.delegate = Objects.requireNonNull(delegate, "delegate must not be null"); + Objects.requireNonNull(meterRegistry, "meterRegistry must not be null"); + + this.metadataPush = new InteractionCounters(meterRegistry, "metadata.push", tags); + this.requestChannel = new InteractionCounters(meterRegistry, "request.channel", tags); + this.requestFireAndForget = new InteractionCounters(meterRegistry, "request.fnf", tags); + this.requestResponse = new InteractionTimers(meterRegistry, "request.response", tags); + this.requestStream = new InteractionCounters(meterRegistry, "request.stream", tags); + } + + @Override + public void dispose() { + delegate.dispose(); + } + + @Override + public Mono fireAndForget(Payload payload) { + return delegate.fireAndForget(payload).doFinally(requestFireAndForget); + } + + @Override + public Mono metadataPush(Payload payload) { + return delegate.metadataPush(payload).doFinally(metadataPush); + } + + @Override + public Mono onClose() { + return delegate.onClose(); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return delegate.requestChannel(payloads).doFinally(requestChannel); + } + + @Override + public Mono requestResponse(Payload payload) { + return Mono.defer( + () -> { + Sample sample = requestResponse.start(); + + return delegate + .requestResponse(payload) + .doFinally(signalType -> requestResponse.accept(sample, signalType)); + }); + } + + @Override + public Flux requestStream(Payload payload) { + return delegate.requestStream(payload).doFinally(requestStream); + } + + private static final class InteractionCounters implements Consumer { + + private final Counter cancel; + + private final Counter onComplete; + + private final Counter onError; + + private InteractionCounters(MeterRegistry meterRegistry, String interactionModel, Tag... tags) { + this.cancel = counter(meterRegistry, interactionModel, CANCEL, tags); + this.onComplete = counter(meterRegistry, interactionModel, ON_COMPLETE, tags); + this.onError = counter(meterRegistry, interactionModel, ON_ERROR, tags); + } + + @Override + public void accept(SignalType signalType) { + switch (signalType) { + case CANCEL: + cancel.increment(); + break; + case ON_COMPLETE: + onComplete.increment(); + break; + case ON_ERROR: + onError.increment(); + break; + } + } + + private static Counter counter( + MeterRegistry meterRegistry, String interactionModel, SignalType signalType, Tag... tags) { + + return meterRegistry.counter( + "rsocket." + interactionModel, Tags.of(tags).and("signal.type", signalType.name())); + } + } + + private static final class InteractionTimers implements BiConsumer { + + private final Timer cancel; + + private final MeterRegistry meterRegistry; + + private final Timer onComplete; + + private final Timer onError; + + private InteractionTimers(MeterRegistry meterRegistry, String interactionModel, Tag... tags) { + this.meterRegistry = meterRegistry; + + this.cancel = timer(meterRegistry, interactionModel, CANCEL, tags); + this.onComplete = timer(meterRegistry, interactionModel, ON_COMPLETE, tags); + this.onError = timer(meterRegistry, interactionModel, ON_ERROR, tags); + } + + @Override + public void accept(Sample sample, SignalType signalType) { + switch (signalType) { + case CANCEL: + sample.stop(cancel); + break; + case ON_COMPLETE: + sample.stop(onComplete); + break; + case ON_ERROR: + sample.stop(onError); + break; + } + } + + Sample start() { + return Timer.start(meterRegistry); + } + + private static Timer timer( + MeterRegistry meterRegistry, String interactionModel, SignalType signalType, Tag... tags) { + + return meterRegistry.timer( + "rsocket." + interactionModel, Tags.of(tags).and("signal.type", signalType.name())); + } + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerRSocketInterceptor.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerRSocketInterceptor.java new file mode 100644 index 000000000..c405c8601 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerRSocketInterceptor.java @@ -0,0 +1,61 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer; + +import io.micrometer.core.instrument.Meter; +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.Tag; +import io.rsocket.RSocket; +import io.rsocket.plugins.RSocketInterceptor; +import java.util.Objects; +import reactor.core.publisher.SignalType; + +/** + * An implementation of {@link RSocketInterceptor} that intercepts interactions and gathers + * Micrometer metrics about them. + * + *

The metrics are called {@code rsocket.[ metadata.push | request.channel | request.fnf | + * request.response | request.stream ]} and is tagged with {@code signal.type} ({@link SignalType}) + * and any additional configured tags. + * + * @see Micrometer + */ +public final class MicrometerRSocketInterceptor implements RSocketInterceptor { + + private final MeterRegistry meterRegistry; + + private final Tag[] tags; + + /** + * Creates a new {@link RSocketInterceptor}. + * + * @param meterRegistry the {@link MeterRegistry} to use to create {@link Meter}s. + * @param tags the additional tags to attach to each {@link Meter} + * @throws NullPointerException if {@code meterRegistry} is {@code null} + */ + public MicrometerRSocketInterceptor(MeterRegistry meterRegistry, Tag... tags) { + this.meterRegistry = Objects.requireNonNull(meterRegistry, "meterRegistry must not be null"); + this.tags = tags; + } + + @Override + public MicrometerRSocket apply(RSocket delegate) { + Objects.requireNonNull(delegate, "delegate must not be null"); + + return new MicrometerRSocket(delegate, meterRegistry, tags); + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufGetter.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufGetter.java new file mode 100644 index 000000000..09c8ba316 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufGetter.java @@ -0,0 +1,36 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.tracing.propagation.Propagator; +import io.netty.buffer.ByteBuf; +import io.netty.util.CharsetUtil; +import io.rsocket.metadata.CompositeMetadata; + +public class ByteBufGetter implements Propagator.Getter { + + @Override + public String get(ByteBuf carrier, String key) { + final CompositeMetadata compositeMetadata = new CompositeMetadata(carrier, false); + for (CompositeMetadata.Entry entry : compositeMetadata) { + if (key.equals(entry.getMimeType())) { + return entry.getContent().toString(CharsetUtil.UTF_8); + } + } + return null; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufSetter.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufSetter.java new file mode 100644 index 000000000..678bdb1ed --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufSetter.java @@ -0,0 +1,33 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.tracing.propagation.Propagator; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.metadata.CompositeMetadataCodec; + +public class ByteBufSetter implements Propagator.Setter { + + @Override + public void set(CompositeByteBuf carrier, String key, String value) { + final ByteBufAllocator alloc = carrier.alloc(); + CompositeMetadataCodec.encodeAndAddMetadataWithCompression( + carrier, alloc, key, ByteBufUtil.writeUtf8(alloc, value)); + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/CompositeMetadataUtils.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/CompositeMetadataUtils.java new file mode 100644 index 000000000..357be8f15 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/CompositeMetadataUtils.java @@ -0,0 +1,40 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.core.lang.Nullable; +import io.netty.buffer.ByteBuf; +import io.rsocket.metadata.CompositeMetadata; + +final class CompositeMetadataUtils { + + private CompositeMetadataUtils() { + throw new IllegalStateException("Can't instantiate a utility class"); + } + + @Nullable + static ByteBuf extract(ByteBuf metadata, String key) { + final CompositeMetadata compositeMetadata = new CompositeMetadata(metadata, false); + for (CompositeMetadata.Entry entry : compositeMetadata) { + final String entryKey = entry.getMimeType(); + if (key.equals(entryKey)) { + return entry.getContent(); + } + } + return null; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketObservationConvention.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketObservationConvention.java new file mode 100644 index 000000000..2c10fc78d --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketObservationConvention.java @@ -0,0 +1,49 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.rsocket.frame.FrameType; + +/** + * Default {@link RSocketRequesterObservationConvention} implementation. + * + * @author Marcin Grzejszczak + * @since 1.1.4 + */ +class DefaultRSocketObservationConvention { + + private final RSocketContext rSocketContext; + + public DefaultRSocketObservationConvention(RSocketContext rSocketContext) { + this.rSocketContext = rSocketContext; + } + + String getName() { + if (this.rSocketContext.frameType == FrameType.REQUEST_FNF) { + return "rsocket.fnf"; + } else if (this.rSocketContext.frameType == FrameType.REQUEST_STREAM) { + return "rsocket.stream"; + } else if (this.rSocketContext.frameType == FrameType.REQUEST_CHANNEL) { + return "rsocket.channel"; + } + return "%s"; + } + + protected RSocketContext getRSocketContext() { + return this.rSocketContext; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketRequesterObservationConvention.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketRequesterObservationConvention.java new file mode 100644 index 000000000..73e04b749 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketRequesterObservationConvention.java @@ -0,0 +1,62 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.KeyValues; +import io.micrometer.common.util.StringUtils; +import io.micrometer.observation.Observation; +import io.rsocket.frame.FrameType; + +/** + * Default {@link RSocketRequesterObservationConvention} implementation. + * + * @author Marcin Grzejszczak + * @since 1.1.4 + */ +public class DefaultRSocketRequesterObservationConvention + extends DefaultRSocketObservationConvention implements RSocketRequesterObservationConvention { + + public DefaultRSocketRequesterObservationConvention(RSocketContext rSocketContext) { + super(rSocketContext); + } + + @Override + public KeyValues getLowCardinalityKeyValues(RSocketContext context) { + KeyValues values = + KeyValues.of( + RSocketObservationDocumentation.ResponderTags.REQUEST_TYPE.withValue( + context.frameType.name())); + if (StringUtils.isNotBlank(context.route)) { + values = + values.and(RSocketObservationDocumentation.ResponderTags.ROUTE.withValue(context.route)); + } + return values; + } + + @Override + public boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext; + } + + @Override + public String getName() { + if (getRSocketContext().frameType == FrameType.REQUEST_RESPONSE) { + return "rsocket.request"; + } + return super.getName(); + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketResponderObservationConvention.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketResponderObservationConvention.java new file mode 100644 index 000000000..5318c1b37 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketResponderObservationConvention.java @@ -0,0 +1,61 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.KeyValues; +import io.micrometer.common.util.StringUtils; +import io.micrometer.observation.Observation; +import io.rsocket.frame.FrameType; + +/** + * Default {@link RSocketRequesterObservationConvention} implementation. + * + * @author Marcin Grzejszczak + * @since 1.1.4 + */ +public class DefaultRSocketResponderObservationConvention + extends DefaultRSocketObservationConvention implements RSocketResponderObservationConvention { + + public DefaultRSocketResponderObservationConvention(RSocketContext rSocketContext) { + super(rSocketContext); + } + + @Override + public KeyValues getLowCardinalityKeyValues(RSocketContext context) { + KeyValues tags = + KeyValues.of( + RSocketObservationDocumentation.ResponderTags.REQUEST_TYPE.withValue( + context.frameType.name())); + if (StringUtils.isNotBlank(context.route)) { + tags = tags.and(RSocketObservationDocumentation.ResponderTags.ROUTE.withValue(context.route)); + } + return tags; + } + + @Override + public boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext; + } + + @Override + public String getName() { + if (getRSocketContext().frameType == FrameType.REQUEST_RESPONSE) { + return "rsocket.response"; + } + return super.getName(); + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationRequesterRSocketProxy.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationRequesterRSocketProxy.java new file mode 100644 index 000000000..fb80ea317 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationRequesterRSocketProxy.java @@ -0,0 +1,208 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.util.StringUtils; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.docs.ObservationDocumentation; +import io.netty.buffer.ByteBuf; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import io.rsocket.metadata.RoutingMetadata; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.util.RSocketProxy; +import java.util.Iterator; +import java.util.function.Function; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; +import reactor.util.context.ContextView; + +/** + * Tracing representation of a {@link RSocketProxy} for the requester. + * + * @author Marcin Grzejszczak + * @author Oleh Dokuka + * @since 1.1.4 + */ +public class ObservationRequesterRSocketProxy extends RSocketProxy { + + /** Aligned with ObservationThreadLocalAccessor#KEY */ + private static final String MICROMETER_OBSERVATION_KEY = "micrometer.observation"; + + private final ObservationRegistry observationRegistry; + + @Nullable private final RSocketRequesterObservationConvention observationConvention; + + public ObservationRequesterRSocketProxy(RSocket source, ObservationRegistry observationRegistry) { + this(source, observationRegistry, null); + } + + public ObservationRequesterRSocketProxy( + RSocket source, + ObservationRegistry observationRegistry, + RSocketRequesterObservationConvention observationConvention) { + super(source); + this.observationRegistry = observationRegistry; + this.observationConvention = observationConvention; + } + + @Override + public Mono fireAndForget(Payload payload) { + return setObservation( + super::fireAndForget, + payload, + FrameType.REQUEST_FNF, + RSocketObservationDocumentation.RSOCKET_REQUESTER_FNF); + } + + @Override + public Mono requestResponse(Payload payload) { + return setObservation( + super::requestResponse, + payload, + FrameType.REQUEST_RESPONSE, + RSocketObservationDocumentation.RSOCKET_REQUESTER_REQUEST_RESPONSE); + } + + Mono setObservation( + Function> input, + Payload payload, + FrameType frameType, + ObservationDocumentation observation) { + return Mono.deferContextual( + contextView -> observe(input, payload, frameType, observation, contextView)); + } + + private String route(Payload payload) { + if (payload.hasMetadata()) { + try { + ByteBuf extracted = + CompositeMetadataUtils.extract( + payload.sliceMetadata(), WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); + final RoutingMetadata routingMetadata = new RoutingMetadata(extracted); + final Iterator iterator = routingMetadata.iterator(); + return iterator.next(); + } catch (Exception e) { + + } + } + return null; + } + + private Mono observe( + Function> input, + Payload payload, + FrameType frameType, + ObservationDocumentation obs, + ContextView contextView) { + String route = route(payload); + RSocketContext rSocketContext = + new RSocketContext( + payload, payload.sliceMetadata(), frameType, route, RSocketContext.Side.REQUESTER); + Observation parentObservation = contextView.getOrDefault(MICROMETER_OBSERVATION_KEY, null); + Observation observation = + obs.observation( + this.observationConvention, + new DefaultRSocketRequesterObservationConvention(rSocketContext), + () -> rSocketContext, + observationRegistry) + .parentObservation(parentObservation); + setContextualName(frameType, route, observation); + observation.start(); + Payload newPayload = payload; + if (rSocketContext.modifiedPayload != null) { + newPayload = rSocketContext.modifiedPayload; + } + return input + .apply(newPayload) + .doOnError(observation::error) + .doFinally(signalType -> observation.stop()) + .contextWrite(context -> context.put(MICROMETER_OBSERVATION_KEY, observation)); + } + + @Override + public Flux requestStream(Payload payload) { + return observationFlux( + super::requestStream, + payload, + FrameType.REQUEST_STREAM, + RSocketObservationDocumentation.RSOCKET_REQUESTER_REQUEST_STREAM); + } + + @Override + public Flux requestChannel(Publisher inbound) { + return Flux.from(inbound) + .switchOnFirst( + (firstSignal, flux) -> { + final Payload firstPayload = firstSignal.get(); + if (firstPayload != null) { + return observationFlux( + p -> super.requestChannel(flux.skip(1).startWith(p)), + firstPayload, + FrameType.REQUEST_CHANNEL, + RSocketObservationDocumentation.RSOCKET_REQUESTER_REQUEST_CHANNEL); + } + return flux; + }); + } + + private Flux observationFlux( + Function> input, + Payload payload, + FrameType frameType, + ObservationDocumentation obs) { + return Flux.deferContextual( + contextView -> { + String route = route(payload); + RSocketContext rSocketContext = + new RSocketContext( + payload, + payload.sliceMetadata(), + frameType, + route, + RSocketContext.Side.REQUESTER); + Observation parentObservation = + contextView.getOrDefault(MICROMETER_OBSERVATION_KEY, null); + Observation newObservation = + obs.observation( + this.observationConvention, + new DefaultRSocketRequesterObservationConvention(rSocketContext), + () -> rSocketContext, + this.observationRegistry) + .parentObservation(parentObservation); + setContextualName(frameType, route, newObservation); + newObservation.start(); + return input + .apply(rSocketContext.modifiedPayload) + .doOnError(newObservation::error) + .doFinally(signalType -> newObservation.stop()) + .contextWrite(context -> context.put(MICROMETER_OBSERVATION_KEY, newObservation)); + }); + } + + private void setContextualName(FrameType frameType, String route, Observation newObservation) { + if (StringUtils.isNotBlank(route)) { + newObservation.contextualName(frameType.name() + " " + route); + } else { + newObservation.contextualName(frameType.name()); + } + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationResponderRSocketProxy.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationResponderRSocketProxy.java new file mode 100644 index 000000000..9ed27adf3 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationResponderRSocketProxy.java @@ -0,0 +1,179 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.util.StringUtils; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.netty.buffer.ByteBuf; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import io.rsocket.metadata.RoutingMetadata; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.util.RSocketProxy; +import java.util.Iterator; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; + +/** + * Tracing representation of a {@link RSocketProxy} for the responder. + * + * @author Marcin Grzejszczak + * @author Oleh Dokuka + * @since 1.1.4 + */ +public class ObservationResponderRSocketProxy extends RSocketProxy { + /** Aligned with ObservationThreadLocalAccessor#KEY */ + private static final String MICROMETER_OBSERVATION_KEY = "micrometer.observation"; + + private final ObservationRegistry observationRegistry; + + @Nullable private final RSocketResponderObservationConvention observationConvention; + + public ObservationResponderRSocketProxy(RSocket source, ObservationRegistry observationRegistry) { + this(source, observationRegistry, null); + } + + public ObservationResponderRSocketProxy( + RSocket source, + ObservationRegistry observationRegistry, + RSocketResponderObservationConvention observationConvention) { + super(source); + this.observationRegistry = observationRegistry; + this.observationConvention = observationConvention; + } + + @Override + public Mono fireAndForget(Payload payload) { + // called on Netty EventLoop + // there can't be observation in thread local here + ByteBuf sliceMetadata = payload.sliceMetadata(); + String route = route(payload, sliceMetadata); + RSocketContext rSocketContext = + new RSocketContext( + payload, + payload.sliceMetadata(), + FrameType.REQUEST_FNF, + route, + RSocketContext.Side.RESPONDER); + Observation newObservation = + startObservation(RSocketObservationDocumentation.RSOCKET_RESPONDER_FNF, rSocketContext); + return super.fireAndForget(rSocketContext.modifiedPayload) + .doOnError(newObservation::error) + .doFinally(signalType -> newObservation.stop()) + .contextWrite(context -> context.put(MICROMETER_OBSERVATION_KEY, newObservation)); + } + + private Observation startObservation( + RSocketObservationDocumentation observation, RSocketContext rSocketContext) { + return observation.start( + this.observationConvention, + new DefaultRSocketResponderObservationConvention(rSocketContext), + () -> rSocketContext, + this.observationRegistry); + } + + @Override + public Mono requestResponse(Payload payload) { + ByteBuf sliceMetadata = payload.sliceMetadata(); + String route = route(payload, sliceMetadata); + RSocketContext rSocketContext = + new RSocketContext( + payload, + payload.sliceMetadata(), + FrameType.REQUEST_RESPONSE, + route, + RSocketContext.Side.RESPONDER); + Observation newObservation = + startObservation( + RSocketObservationDocumentation.RSOCKET_RESPONDER_REQUEST_RESPONSE, rSocketContext); + return super.requestResponse(rSocketContext.modifiedPayload) + .doOnError(newObservation::error) + .doFinally(signalType -> newObservation.stop()) + .contextWrite(context -> context.put(MICROMETER_OBSERVATION_KEY, newObservation)); + } + + @Override + public Flux requestStream(Payload payload) { + ByteBuf sliceMetadata = payload.sliceMetadata(); + String route = route(payload, sliceMetadata); + RSocketContext rSocketContext = + new RSocketContext( + payload, sliceMetadata, FrameType.REQUEST_STREAM, route, RSocketContext.Side.RESPONDER); + Observation newObservation = + startObservation( + RSocketObservationDocumentation.RSOCKET_RESPONDER_REQUEST_STREAM, rSocketContext); + return super.requestStream(rSocketContext.modifiedPayload) + .doOnError(newObservation::error) + .doFinally(signalType -> newObservation.stop()) + .contextWrite(context -> context.put(MICROMETER_OBSERVATION_KEY, newObservation)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads) + .switchOnFirst( + (firstSignal, flux) -> { + final Payload firstPayload = firstSignal.get(); + if (firstPayload != null) { + ByteBuf sliceMetadata = firstPayload.sliceMetadata(); + String route = route(firstPayload, sliceMetadata); + RSocketContext rSocketContext = + new RSocketContext( + firstPayload, + firstPayload.sliceMetadata(), + FrameType.REQUEST_CHANNEL, + route, + RSocketContext.Side.RESPONDER); + Observation newObservation = + startObservation( + RSocketObservationDocumentation.RSOCKET_RESPONDER_REQUEST_CHANNEL, + rSocketContext); + if (StringUtils.isNotBlank(route)) { + newObservation.contextualName(rSocketContext.frameType.name() + " " + route); + } + return super.requestChannel(flux.skip(1).startWith(rSocketContext.modifiedPayload)) + .doOnError(newObservation::error) + .doFinally(signalType -> newObservation.stop()) + .contextWrite( + context -> context.put(MICROMETER_OBSERVATION_KEY, newObservation)); + } + return flux; + }); + } + + private String route(Payload payload, ByteBuf headers) { + if (payload.hasMetadata()) { + try { + final ByteBuf extract = + CompositeMetadataUtils.extract( + headers, WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); + if (extract != null) { + final RoutingMetadata routingMetadata = new RoutingMetadata(extract); + final Iterator iterator = routingMetadata.iterator(); + return iterator.next(); + } + } catch (Exception e) { + + } + } + return null; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/PayloadUtils.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/PayloadUtils.java new file mode 100644 index 000000000..e5286a53f --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/PayloadUtils.java @@ -0,0 +1,73 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.Payload; +import io.rsocket.metadata.CompositeMetadata; +import io.rsocket.metadata.CompositeMetadata.Entry; +import io.rsocket.metadata.CompositeMetadataCodec; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import java.util.HashSet; +import java.util.Set; + +final class PayloadUtils { + + private PayloadUtils() { + throw new IllegalStateException("Can't instantiate a utility class"); + } + + static CompositeByteBuf cleanTracingMetadata(Payload payload, Set fields) { + Set fieldsWithDefaultZipkin = new HashSet<>(fields); + fieldsWithDefaultZipkin.add(WellKnownMimeType.MESSAGE_RSOCKET_TRACING_ZIPKIN.getString()); + final CompositeByteBuf metadata = ByteBufAllocator.DEFAULT.compositeBuffer(); + if (payload.hasMetadata()) { + try { + final CompositeMetadata entries = new CompositeMetadata(payload.metadata(), false); + for (Entry entry : entries) { + if (!fieldsWithDefaultZipkin.contains(entry.getMimeType())) { + CompositeMetadataCodec.encodeAndAddMetadataWithCompression( + metadata, + ByteBufAllocator.DEFAULT, + entry.getMimeType(), + entry.getContent().retain()); + } + } + } catch (Exception e) { + + } + } + return metadata; + } + + static Payload payload(Payload payload, CompositeByteBuf metadata) { + final Payload newPayload; + try { + if (payload instanceof ByteBufPayload) { + newPayload = ByteBufPayload.create(payload.data().retain(), metadata); + } else { + newPayload = DefaultPayload.create(payload.data().retain(), metadata); + } + } finally { + payload.release(); + } + return newPayload; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketContext.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketContext.java new file mode 100644 index 000000000..8622cdfa5 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketContext.java @@ -0,0 +1,76 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.lang.Nullable; +import io.micrometer.observation.Observation; +import io.netty.buffer.ByteBuf; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; + +public class RSocketContext extends Observation.Context { + + final Payload payload; + + final ByteBuf metadata; + + final FrameType frameType; + + final String route; + + final Side side; + + Payload modifiedPayload; + + RSocketContext( + Payload payload, ByteBuf metadata, FrameType frameType, @Nullable String route, Side side) { + this.payload = payload; + this.metadata = metadata; + this.frameType = frameType; + this.route = route; + this.side = side; + } + + public enum Side { + REQUESTER, + RESPONDER + } + + public Payload getPayload() { + return payload; + } + + public ByteBuf getMetadata() { + return metadata; + } + + public FrameType getFrameType() { + return frameType; + } + + public String getRoute() { + return route; + } + + public Side getSide() { + return side; + } + + public Payload getModifiedPayload() { + return modifiedPayload; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketObservationDocumentation.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketObservationDocumentation.java new file mode 100644 index 000000000..1be6b4599 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketObservationDocumentation.java @@ -0,0 +1,232 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.docs.KeyName; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationConvention; +import io.micrometer.observation.docs.ObservationDocumentation; + +enum RSocketObservationDocumentation implements ObservationDocumentation { + + /** Observation created on the RSocket responder side. */ + RSOCKET_RESPONDER { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketResponderObservationConvention.class; + } + }, + + /** Observation created on the RSocket requester side for Fire and Forget frame type. */ + RSOCKET_REQUESTER_FNF { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketRequesterObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return RequesterTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket responder side for Fire and Forget frame type. */ + RSOCKET_RESPONDER_FNF { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketResponderObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return ResponderTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket requester side for Request Response frame type. */ + RSOCKET_REQUESTER_REQUEST_RESPONSE { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketRequesterObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return RequesterTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket responder side for Request Response frame type. */ + RSOCKET_RESPONDER_REQUEST_RESPONSE { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketResponderObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return ResponderTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket requester side for Request Stream frame type. */ + RSOCKET_REQUESTER_REQUEST_STREAM { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketRequesterObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return RequesterTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket responder side for Request Stream frame type. */ + RSOCKET_RESPONDER_REQUEST_STREAM { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketResponderObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return ResponderTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket requester side for Request Channel frame type. */ + RSOCKET_REQUESTER_REQUEST_CHANNEL { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketRequesterObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return RequesterTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket responder side for Request Channel frame type. */ + RSOCKET_RESPONDER_REQUEST_CHANNEL { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketResponderObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return ResponderTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }; + + enum RequesterTags implements KeyName { + + /** Name of the RSocket route. */ + ROUTE { + @Override + public String asString() { + return "rsocket.route"; + } + }, + + /** Name of the RSocket request type. */ + REQUEST_TYPE { + @Override + public String asString() { + return "rsocket.request-type"; + } + }, + + /** Name of the RSocket content type. */ + CONTENT_TYPE { + @Override + public String asString() { + return "rsocket.content-type"; + } + } + } + + enum ResponderTags implements KeyName { + + /** Name of the RSocket route. */ + ROUTE { + @Override + public String asString() { + return "rsocket.route"; + } + }, + + /** Name of the RSocket request type. */ + REQUEST_TYPE { + @Override + public String asString() { + return "rsocket.request-type"; + } + } + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterObservationConvention.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterObservationConvention.java new file mode 100644 index 000000000..d795f81b5 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterObservationConvention.java @@ -0,0 +1,36 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationConvention; + +/** + * {@link ObservationConvention} for RSocket requester {@link RSocketContext}. + * + * @author Marcin Grzejszczak + * @since 1.1.4 + */ +public interface RSocketRequesterObservationConvention + extends ObservationConvention { + + @Override + default boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext + && ((RSocketContext) context).side == RSocketContext.Side.REQUESTER; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterTracingObservationHandler.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterTracingObservationHandler.java new file mode 100644 index 000000000..996267d4a --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterTracingObservationHandler.java @@ -0,0 +1,131 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.observation.Observation; +import io.micrometer.tracing.Span; +import io.micrometer.tracing.TraceContext; +import io.micrometer.tracing.Tracer; +import io.micrometer.tracing.handler.TracingObservationHandler; +import io.micrometer.tracing.internal.EncodingUtils; +import io.micrometer.tracing.propagation.Propagator; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.Payload; +import io.rsocket.metadata.TracingMetadataCodec; +import java.util.HashSet; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class RSocketRequesterTracingObservationHandler + implements TracingObservationHandler { + private static final Logger log = + LoggerFactory.getLogger(RSocketRequesterTracingObservationHandler.class); + + private final Propagator propagator; + + private final Propagator.Setter setter; + + private final Tracer tracer; + + private final boolean isZipkinPropagationEnabled; + + public RSocketRequesterTracingObservationHandler( + Tracer tracer, + Propagator propagator, + Propagator.Setter setter, + boolean isZipkinPropagationEnabled) { + this.tracer = tracer; + this.propagator = propagator; + this.setter = setter; + this.isZipkinPropagationEnabled = isZipkinPropagationEnabled; + } + + @Override + public boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext + && ((RSocketContext) context).side == RSocketContext.Side.REQUESTER; + } + + @Override + public Tracer getTracer() { + return this.tracer; + } + + @Override + public void onStart(RSocketContext context) { + Payload payload = context.payload; + Span.Builder spanBuilder = this.tracer.spanBuilder(); + Span parentSpan = getParentSpan(context); + if (parentSpan != null) { + spanBuilder.setParent(parentSpan.context()); + } + Span span = spanBuilder.kind(Span.Kind.PRODUCER).start(); + log.debug("Extracted result from context or thread local {}", span); + // TODO: newmetadata returns an empty composite byte buf + final CompositeByteBuf newMetadata = + PayloadUtils.cleanTracingMetadata(payload, new HashSet<>(propagator.fields())); + TraceContext traceContext = span.context(); + if (this.isZipkinPropagationEnabled) { + injectDefaultZipkinRSocketHeaders(newMetadata, traceContext); + } + this.propagator.inject(traceContext, newMetadata, this.setter); + context.modifiedPayload = PayloadUtils.payload(payload, newMetadata); + getTracingContext(context).setSpan(span); + } + + @Override + public void onError(RSocketContext context) { + Throwable error = context.getError(); + if (error != null) { + getRequiredSpan(context).error(error); + } + } + + @Override + public void onStop(RSocketContext context) { + Span span = getRequiredSpan(context); + tagSpan(context, span); + span.name(context.getContextualName()).end(); + } + + private void injectDefaultZipkinRSocketHeaders( + CompositeByteBuf newMetadata, TraceContext traceContext) { + TracingMetadataCodec.Flags flags = + traceContext.sampled() == null + ? TracingMetadataCodec.Flags.UNDECIDED + : traceContext.sampled() + ? TracingMetadataCodec.Flags.SAMPLE + : TracingMetadataCodec.Flags.NOT_SAMPLE; + String traceId = traceContext.traceId(); + long[] traceIds = EncodingUtils.fromString(traceId); + long[] spanId = EncodingUtils.fromString(traceContext.spanId()); + long[] parentSpanId = EncodingUtils.fromString(traceContext.parentId()); + boolean isTraceId128Bit = traceIds.length == 2; + if (isTraceId128Bit) { + TracingMetadataCodec.encode128( + newMetadata.alloc(), + traceIds[0], + traceIds[1], + spanId[0], + EncodingUtils.fromString(traceContext.parentId())[0], + flags); + } else { + TracingMetadataCodec.encode64( + newMetadata.alloc(), traceIds[0], spanId[0], parentSpanId[0], flags); + } + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderObservationConvention.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderObservationConvention.java new file mode 100644 index 000000000..a5d6808bd --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderObservationConvention.java @@ -0,0 +1,36 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationConvention; + +/** + * {@link ObservationConvention} for RSocket responder {@link RSocketContext}. + * + * @author Marcin Grzejszczak + * @since 1.1.4 + */ +public interface RSocketResponderObservationConvention + extends ObservationConvention { + + @Override + default boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext + && ((RSocketContext) context).side == RSocketContext.Side.RESPONDER; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderTracingObservationHandler.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderTracingObservationHandler.java new file mode 100644 index 000000000..e3975b577 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderTracingObservationHandler.java @@ -0,0 +1,152 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.observation.Observation; +import io.micrometer.tracing.Span; +import io.micrometer.tracing.TraceContext; +import io.micrometer.tracing.Tracer; +import io.micrometer.tracing.handler.TracingObservationHandler; +import io.micrometer.tracing.internal.EncodingUtils; +import io.micrometer.tracing.propagation.Propagator; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; +import io.rsocket.metadata.RoutingMetadata; +import io.rsocket.metadata.TracingMetadata; +import io.rsocket.metadata.TracingMetadataCodec; +import io.rsocket.metadata.WellKnownMimeType; +import java.util.HashSet; +import java.util.Iterator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class RSocketResponderTracingObservationHandler + implements TracingObservationHandler { + + private static final Logger log = + LoggerFactory.getLogger(RSocketResponderTracingObservationHandler.class); + + private final Propagator propagator; + + private final Propagator.Getter getter; + + private final Tracer tracer; + + private final boolean isZipkinPropagationEnabled; + + public RSocketResponderTracingObservationHandler( + Tracer tracer, + Propagator propagator, + Propagator.Getter getter, + boolean isZipkinPropagationEnabled) { + this.tracer = tracer; + this.propagator = propagator; + this.getter = getter; + this.isZipkinPropagationEnabled = isZipkinPropagationEnabled; + } + + @Override + public void onStart(RSocketContext context) { + Span handle = consumerSpanBuilder(context.payload, context.metadata, context.frameType); + CompositeByteBuf bufs = + PayloadUtils.cleanTracingMetadata(context.payload, new HashSet<>(propagator.fields())); + context.modifiedPayload = PayloadUtils.payload(context.payload, bufs); + getTracingContext(context).setSpan(handle); + } + + @Override + public void onError(RSocketContext context) { + Throwable error = context.getError(); + if (error != null) { + getRequiredSpan(context).error(error); + } + } + + @Override + public void onStop(RSocketContext context) { + Span span = getRequiredSpan(context); + tagSpan(context, span); + span.end(); + } + + @Override + public boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext + && ((RSocketContext) context).side == RSocketContext.Side.RESPONDER; + } + + @Override + public Tracer getTracer() { + return this.tracer; + } + + private Span consumerSpanBuilder(Payload payload, ByteBuf headers, FrameType requestType) { + Span.Builder consumerSpanBuilder = consumerSpanBuilder(payload, headers); + log.debug("Extracted result from headers {}", consumerSpanBuilder); + String name = "handle"; + if (payload.hasMetadata()) { + try { + final ByteBuf extract = + CompositeMetadataUtils.extract( + headers, WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); + if (extract != null) { + final RoutingMetadata routingMetadata = new RoutingMetadata(extract); + final Iterator iterator = routingMetadata.iterator(); + name = requestType.name() + " " + iterator.next(); + } + } catch (Exception e) { + + } + } + return consumerSpanBuilder.kind(Span.Kind.CONSUMER).name(name).start(); + } + + private Span.Builder consumerSpanBuilder(Payload payload, ByteBuf headers) { + if (this.isZipkinPropagationEnabled && payload.hasMetadata()) { + try { + ByteBuf extract = + CompositeMetadataUtils.extract( + headers, WellKnownMimeType.MESSAGE_RSOCKET_TRACING_ZIPKIN.getString()); + if (extract != null) { + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(extract); + Span.Builder builder = this.tracer.spanBuilder(); + String traceId = EncodingUtils.fromLong(tracingMetadata.traceId()); + long traceIdHigh = tracingMetadata.traceIdHigh(); + if (traceIdHigh != 0L) { + // ExtendedTraceId + traceId = EncodingUtils.fromLong(traceIdHigh) + traceId; + } + TraceContext.Builder parentBuilder = + this.tracer + .traceContextBuilder() + .sampled(tracingMetadata.isDebug() || tracingMetadata.isSampled()) + .traceId(traceId) + .spanId(EncodingUtils.fromLong(tracingMetadata.spanId())) + .parentId(EncodingUtils.fromLong(tracingMetadata.parentId())); + return builder.setParent(parentBuilder.build()); + } else { + return this.propagator.extract(headers, this.getter); + } + } catch (Exception e) { + + } + } + return this.propagator.extract(headers, this.getter); + } +} diff --git a/rsocket-spectator/build.gradle b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/package-info.java similarity index 62% rename from rsocket-spectator/build.gradle rename to rsocket-micrometer/src/main/java/io/rsocket/micrometer/package-info.java index f7eca3545..c95f2ce02 100644 --- a/rsocket-spectator/build.gradle +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -14,10 +14,12 @@ * limitations under the License. */ -dependencies { - compile project(':rsocket-core') - compile "com.netflix.spectator:spectator-api:$spectatorVersion" - compile "org.hdrhistogram:HdrHistogram:2.1.9" +/** + * Transparent metrics exposure for Micrometer. + * + * @see Micrometer + */ +@NonNullApi +package io.rsocket.micrometer; - testCompile project(':rsocket-test') -} +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerDuplexConnectionInterceptorTest.java b/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerDuplexConnectionInterceptorTest.java new file mode 100644 index 000000000..4ff072252 --- /dev/null +++ b/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerDuplexConnectionInterceptorTest.java @@ -0,0 +1,68 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer; + +import static io.rsocket.plugins.DuplexConnectionInterceptor.Type.CLIENT; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; +import static org.mockito.Mockito.RETURNS_SMART_NULLS; +import static org.mockito.Mockito.mock; + +import io.micrometer.core.instrument.simple.SimpleMeterRegistry; +import io.rsocket.DuplexConnection; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +final class MicrometerDuplexConnectionInterceptorTest { + + private final DuplexConnection delegate = mock(DuplexConnection.class, RETURNS_SMART_NULLS); + + private final SimpleMeterRegistry meterRegistry = new SimpleMeterRegistry(); + + @DisplayName("creates MicrometerDuplexConnection") + @Test + void apply() { + assertThat(new MicrometerDuplexConnectionInterceptor(meterRegistry).apply(CLIENT, delegate)) + .isInstanceOf(MicrometerDuplexConnection.class); + } + + @DisplayName("apply throws NullPointerException with null connectionType") + @Test + void applyNullConnectionType() { + assertThatNullPointerException() + .isThrownBy( + () -> new MicrometerDuplexConnectionInterceptor(meterRegistry).apply(null, delegate)) + .withMessage("connectionType must not be null"); + } + + @DisplayName("apply throws NullPointerException with null delegate") + @Test + void applyNullDelegate() { + assertThatNullPointerException() + .isThrownBy( + () -> new MicrometerDuplexConnectionInterceptor(meterRegistry).apply(CLIENT, null)) + .withMessage("delegate must not be null"); + } + + @DisplayName("constructor throws NullPointer exception with null meterRegistry") + @Test + void constructorNullMeterRegistry() { + assertThatNullPointerException() + .isThrownBy(() -> new MicrometerDuplexConnectionInterceptor(null)) + .withMessage("meterRegistry must not be null"); + } +} diff --git a/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerDuplexConnectionTest.java b/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerDuplexConnectionTest.java new file mode 100644 index 000000000..7806200dd --- /dev/null +++ b/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerDuplexConnectionTest.java @@ -0,0 +1,201 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer; + +import static io.rsocket.frame.FrameType.*; +import static io.rsocket.plugins.DuplexConnectionInterceptor.Type.CLIENT; +import static io.rsocket.plugins.DuplexConnectionInterceptor.Type.SERVER; +import static io.rsocket.test.TestFrames.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; +import static org.mockito.Mockito.*; + +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Tag; +import io.micrometer.core.instrument.simple.SimpleMeterRegistry; +import io.netty.buffer.ByteBuf; +import io.rsocket.DuplexConnection; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.DuplexConnectionInterceptor.Type; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.test.StepVerifier; + +final class MicrometerDuplexConnectionTest { + + private final DuplexConnection delegate = mock(DuplexConnection.class, RETURNS_SMART_NULLS); + + private final SimpleMeterRegistry meterRegistry = new SimpleMeterRegistry(); + + @DisplayName("constructor throws NullPointerException with null connectionType") + @Test + void constructorNullConnectionType() { + assertThatNullPointerException() + .isThrownBy(() -> new MicrometerDuplexConnection(null, delegate, meterRegistry)) + .withMessage("connectionType must not be null"); + } + + @DisplayName("constructor throws NullPointerException with null delegate") + @Test + void constructorNullDelegate() { + assertThatNullPointerException() + .isThrownBy(() -> new MicrometerDuplexConnection(CLIENT, null, meterRegistry)) + .withMessage("delegate must not be null"); + } + + @DisplayName("constructor throws NullPointerException with null meterRegistry") + @Test + void constructorNullMeterRegistry() { + + assertThatNullPointerException() + .isThrownBy(() -> new MicrometerDuplexConnection(CLIENT, delegate, null)) + .withMessage("meterRegistry must not be null"); + } + + @DisplayName("dispose gathers metrics") + @Test + void dispose() { + new MicrometerDuplexConnection( + CLIENT, delegate, meterRegistry, Tag.of("test-key", "test-value")) + .dispose(); + + assertThat( + meterRegistry + .get("rsocket.duplex.connection.dispose") + .tag("connection.type", CLIENT.name()) + .tag("test-key", "test-value") + .counter() + .count()) + .isEqualTo(1); + } + + @DisplayName("onClose gathers metrics") + @Test + void onClose() { + when(delegate.onClose()).thenReturn(Mono.empty()); + + new MicrometerDuplexConnection( + CLIENT, delegate, meterRegistry, Tag.of("test-key", "test-value")) + .onClose() + .subscribe(Operators.drainSubscriber()); + + assertThat( + meterRegistry + .get("rsocket.duplex.connection.close") + .tag("connection.type", CLIENT.name()) + .tag("test-key", "test-value") + .counter() + .count()) + .isEqualTo(1); + } + + @DisplayName("receive gathers metrics") + @Test + void receive() { + Flux frames = + Flux.just( + createTestCancelFrame(), + createTestErrorFrame(), + createTestKeepaliveFrame(), + createTestLeaseFrame(), + createTestMetadataPushFrame(), + createTestPayloadFrame(), + createTestRequestChannelFrame(), + createTestRequestFireAndForgetFrame(), + createTestRequestNFrame(), + createTestRequestResponseFrame(), + createTestRequestStreamFrame(), + createTestSetupFrame()); + + when(delegate.receive()).thenReturn(frames); + + new MicrometerDuplexConnection( + CLIENT, delegate, meterRegistry, Tag.of("test-key", "test-value")) + .receive() + .as(StepVerifier::create) + .expectNextCount(12) + .verifyComplete(); + + assertThat(findCounter(CLIENT, CANCEL).count()).isEqualTo(1); + assertThat(findCounter(CLIENT, COMPLETE).count()).isEqualTo(1); + assertThat(findCounter(CLIENT, ERROR).count()).isEqualTo(1); + assertThat(findCounter(CLIENT, KEEPALIVE).count()).isEqualTo(1); + assertThat(findCounter(CLIENT, LEASE).count()).isEqualTo(1); + assertThat(findCounter(CLIENT, METADATA_PUSH).count()).isEqualTo(1); + assertThat(findCounter(CLIENT, REQUEST_CHANNEL).count()).isEqualTo(1); + assertThat(findCounter(CLIENT, REQUEST_FNF).count()).isEqualTo(1); + assertThat(findCounter(CLIENT, REQUEST_N).count()).isEqualTo(1); + assertThat(findCounter(CLIENT, REQUEST_RESPONSE).count()).isEqualTo(1); + assertThat(findCounter(CLIENT, REQUEST_STREAM).count()).isEqualTo(1); + assertThat(findCounter(CLIENT, SETUP).count()).isEqualTo(1); + } + + @DisplayName("send gathers metrics") + @SuppressWarnings("unchecked") + @Test + void send() { + ArgumentCaptor captor = ArgumentCaptor.forClass(ByteBuf.class); + doNothing().when(delegate).sendFrame(Mockito.anyInt(), captor.capture()); + + final MicrometerDuplexConnection micrometerDuplexConnection = + new MicrometerDuplexConnection( + SERVER, delegate, meterRegistry, Tag.of("test-key", "test-value")); + micrometerDuplexConnection.sendFrame(1, createTestCancelFrame()); + micrometerDuplexConnection.sendFrame(1, createTestErrorFrame()); + micrometerDuplexConnection.sendFrame(1, createTestKeepaliveFrame()); + micrometerDuplexConnection.sendFrame(1, createTestLeaseFrame()); + micrometerDuplexConnection.sendFrame(1, createTestMetadataPushFrame()); + micrometerDuplexConnection.sendFrame(1, createTestPayloadFrame()); + micrometerDuplexConnection.sendFrame(1, createTestRequestChannelFrame()); + micrometerDuplexConnection.sendFrame(1, createTestRequestFireAndForgetFrame()); + micrometerDuplexConnection.sendFrame(1, createTestRequestNFrame()); + micrometerDuplexConnection.sendFrame(1, createTestRequestResponseFrame()); + micrometerDuplexConnection.sendFrame(1, createTestRequestStreamFrame()); + micrometerDuplexConnection.sendFrame(1, createTestSetupFrame()); + + StepVerifier.create(Flux.fromIterable(captor.getAllValues())) + .expectNextCount(12) + .verifyComplete(); + + assertThat(findCounter(SERVER, CANCEL).count()).isEqualTo(1); + assertThat(findCounter(SERVER, COMPLETE).count()).isEqualTo(1); + assertThat(findCounter(SERVER, ERROR).count()).isEqualTo(1); + assertThat(findCounter(SERVER, KEEPALIVE).count()).isEqualTo(1); + assertThat(findCounter(SERVER, LEASE).count()).isEqualTo(1); + assertThat(findCounter(SERVER, METADATA_PUSH).count()).isEqualTo(1); + assertThat(findCounter(SERVER, REQUEST_CHANNEL).count()).isEqualTo(1); + assertThat(findCounter(SERVER, REQUEST_FNF).count()).isEqualTo(1); + assertThat(findCounter(SERVER, REQUEST_N).count()).isEqualTo(1); + assertThat(findCounter(SERVER, REQUEST_RESPONSE).count()).isEqualTo(1); + assertThat(findCounter(SERVER, REQUEST_STREAM).count()).isEqualTo(1); + assertThat(findCounter(SERVER, SETUP).count()).isEqualTo(1); + } + + private Counter findCounter(Type connectionType, FrameType frameType) { + return meterRegistry + .get("rsocket.frame") + .tag("connection.type", connectionType.name()) + .tag("frame.type", frameType.name()) + .tag("test-key", "test-value") + .counter(); + } +} diff --git a/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerRSocketInterceptorTest.java b/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerRSocketInterceptorTest.java new file mode 100644 index 000000000..196ee1aa6 --- /dev/null +++ b/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerRSocketInterceptorTest.java @@ -0,0 +1,57 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; +import static org.mockito.Mockito.RETURNS_SMART_NULLS; +import static org.mockito.Mockito.mock; + +import io.micrometer.core.instrument.simple.SimpleMeterRegistry; +import io.rsocket.RSocket; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +final class MicrometerRSocketInterceptorTest { + + private final RSocket delegate = mock(RSocket.class, RETURNS_SMART_NULLS); + + private final SimpleMeterRegistry meterRegistry = new SimpleMeterRegistry(); + + @DisplayName("creates MicrometerRSocket") + @Test + void apply() { + assertThat(new MicrometerRSocketInterceptor(meterRegistry).apply(delegate)) + .isInstanceOf(MicrometerRSocket.class); + } + + @DisplayName("apply throws NullPointerException with null delegate") + @Test + void applyNullDelegate() { + assertThatNullPointerException() + .isThrownBy(() -> new MicrometerRSocketInterceptor(meterRegistry).apply(null)) + .withMessage("delegate must not be null"); + } + + @DisplayName("constructor throws NullPointerException with null meterRegistry") + @Test + void constructorNullMeterRegistry() { + assertThatNullPointerException() + .isThrownBy(() -> new MicrometerRSocketInterceptor(null)) + .withMessage("meterRegistry must not be null"); + } +} diff --git a/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerRSocketTest.java b/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerRSocketTest.java new file mode 100644 index 000000000..7317c5c59 --- /dev/null +++ b/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerRSocketTest.java @@ -0,0 +1,146 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; +import static org.mockito.Mockito.RETURNS_SMART_NULLS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Tag; +import io.micrometer.core.instrument.Timer; +import io.micrometer.core.instrument.simple.SimpleMeterRegistry; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.util.DefaultPayload; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.SignalType; +import reactor.test.StepVerifier; + +final class MicrometerRSocketTest { + + private final RSocket delegate = mock(RSocket.class, RETURNS_SMART_NULLS); + + private final SimpleMeterRegistry meterRegistry = new SimpleMeterRegistry(); + + @DisplayName("constructor throws NullPointerException with null delegate") + @Test + void constructorNullDelegate() { + assertThatNullPointerException() + .isThrownBy(() -> new MicrometerRSocket(null, meterRegistry)) + .withMessage("delegate must not be null"); + } + + @DisplayName("constructor throws NullPointerException with null meterRegistry") + @Test + void constructorNullMeterRegistry() { + assertThatNullPointerException() + .isThrownBy(() -> new MicrometerRSocket(delegate, null)) + .withMessage("meterRegistry must not be null"); + } + + @DisplayName("fireAndForget gathers metrics") + @Test + void fireAndForget() { + Payload payload = DefaultPayload.create("test-metadata", "test-data"); + when(delegate.fireAndForget(payload)).thenReturn(Mono.empty()); + + new MicrometerRSocket(delegate, meterRegistry, Tag.of("test-key", "test-value")) + .fireAndForget(payload) + .as(StepVerifier::create) + .verifyComplete(); + + assertThat(findCounter("request.fnf", SignalType.ON_COMPLETE).count()).isEqualTo(1); + } + + @DisplayName("metadataPush gathers metrics") + @Test + void metadataPush() { + Payload payload = DefaultPayload.create("test-metadata", "test-data"); + when(delegate.metadataPush(payload)).thenReturn(Mono.empty()); + + new MicrometerRSocket(delegate, meterRegistry, Tag.of("test-key", "test-value")) + .metadataPush(payload) + .as(StepVerifier::create) + .verifyComplete(); + + assertThat(findCounter("metadata.push", SignalType.ON_COMPLETE).count()).isEqualTo(1); + } + + @DisplayName("requestChannel gathers metrics") + @Test + void requestChannel() { + Mono payload = Mono.just(DefaultPayload.create("test-metadata", "test-data")); + when(delegate.requestChannel(payload)).thenReturn(Flux.empty()); + + new MicrometerRSocket(delegate, meterRegistry, Tag.of("test-key", "test-value")) + .requestChannel(payload) + .as(StepVerifier::create) + .verifyComplete(); + + assertThat(findCounter("request.channel", SignalType.ON_COMPLETE).count()).isEqualTo(1); + } + + @DisplayName("requestResponse gathers metrics") + @Test + void requestResponse() { + Payload payload = DefaultPayload.create("test-metadata", "test-data"); + when(delegate.requestResponse(payload)).thenReturn(Mono.empty()); + + new MicrometerRSocket(delegate, meterRegistry, Tag.of("test-key", "test-value")) + .requestResponse(payload) + .as(StepVerifier::create) + .verifyComplete(); + + assertThat(findTimer("request.response", SignalType.ON_COMPLETE).count()).isEqualTo(1); + } + + @DisplayName("requestStream gathers metrics") + @Test + void requestStream() { + Payload payload = DefaultPayload.create("test-metadata", "test-data"); + when(delegate.requestStream(payload)).thenReturn(Flux.empty()); + + new MicrometerRSocket(delegate, meterRegistry, Tag.of("test-key", "test-value")) + .requestStream(payload) + .as(StepVerifier::create) + .verifyComplete(); + + assertThat(findCounter("request.stream", SignalType.ON_COMPLETE).count()).isEqualTo(1); + } + + private Counter findCounter(String interactionModel, SignalType signalType) { + return meterRegistry + .get(String.format("rsocket.%s", interactionModel)) + .tag("signal.type", signalType.name()) + .tag("test-key", "test-value") + .counter(); + } + + private Timer findTimer(String interactionModel, SignalType signalType) { + return meterRegistry + .get(String.format("rsocket.%s", interactionModel)) + .tag("signal.type", signalType.name()) + .tag("test-key", "test-value") + .timer(); + } +} diff --git a/rsocket-micrometer/src/test/resources/logback-test.xml b/rsocket-micrometer/src/test/resources/logback-test.xml new file mode 100644 index 000000000..56e2f9c9b --- /dev/null +++ b/rsocket-micrometer/src/test/resources/logback-test.xml @@ -0,0 +1,32 @@ + + + + + + + + %date{HH:mm:ss.SSS} %-10thread %-42logger %msg%n + + + + + + + + + + diff --git a/rsocket-spectator/src/main/java/io/rsocket/spectator/SpectatorFrameInterceptor.java b/rsocket-spectator/src/main/java/io/rsocket/spectator/SpectatorFrameInterceptor.java deleted file mode 100644 index 3d93f4607..000000000 --- a/rsocket-spectator/src/main/java/io/rsocket/spectator/SpectatorFrameInterceptor.java +++ /dev/null @@ -1,161 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.spectator; - -import com.netflix.spectator.api.Counter; -import com.netflix.spectator.api.Registry; -import io.rsocket.DuplexConnection; -import io.rsocket.Frame; -import io.rsocket.FrameType; -import io.rsocket.plugins.DuplexConnectionInterceptor; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -/** An implementation of {@link DuplexConnectionInterceptor} that uses Spectator */ -public class SpectatorFrameInterceptor implements DuplexConnectionInterceptor { - private final Registry registry; - - public SpectatorFrameInterceptor(Registry registry) { - this.registry = registry; - } - - @Override - public DuplexConnection apply(Type type, DuplexConnection connection) { - return new DuplexConnection() { - Counter cancelCounter = registry.counter(FrameType.CANCEL.name(), type.name()); - Counter requestChannelCounter = - registry.counter(FrameType.REQUEST_CHANNEL.name(), type.name()); - Counter completeCounter = registry.counter(FrameType.COMPLETE.name(), type.name()); - Counter errorCounter = registry.counter(FrameType.ERROR.name(), type.name()); - Counter extCounter = registry.counter(FrameType.EXT.name(), type.name()); - Counter fireAndForgetCounter = - registry.counter(FrameType.FIRE_AND_FORGET.name(), type.name()); - Counter keepAliveCounter = registry.counter(FrameType.KEEPALIVE.name(), type.name()); - Counter leaseCounter = registry.counter(FrameType.LEASE.name(), type.name()); - Counter metadataPushCounter = registry.counter(FrameType.METADATA_PUSH.name(), type.name()); - Counter nextCounter = registry.counter(FrameType.NEXT.name(), type.name()); - Counter nextCompleteCounter = registry.counter(FrameType.NEXT_COMPLETE.name(), type.name()); - Counter payloadCounter = registry.counter(FrameType.PAYLOAD.name(), type.name()); - Counter requestNCounter = registry.counter(FrameType.REQUEST_N.name(), type.name()); - Counter requestResponseCounter = - registry.counter(FrameType.REQUEST_RESPONSE.name(), type.name()); - Counter requestStreamCounter = registry.counter(FrameType.REQUEST_STREAM.name(), type.name()); - Counter resumeCounter = registry.counter(FrameType.RESUME.name(), type.name()); - Counter resumeOkCounter = registry.counter(FrameType.RESUME_OK.name(), type.name()); - Counter setupCounter = registry.counter(FrameType.SETUP.name(), type.name()); - Counter undefinedCounter = registry.counter(FrameType.UNDEFINED.name(), type.name()); - - @Override - public Mono send(Publisher frame) { - return connection.send(Flux.from(frame).doOnNext(this::count)); - } - - @Override - public Mono sendOne(Frame frame) { - return Mono.defer( - () -> { - count(frame); - return connection.sendOne(frame); - }); - } - - @Override - public Flux receive() { - return connection.receive().doOnNext(this::count); - } - - @Override - public Mono close() { - return connection.close(); - } - - @Override - public Mono onClose() { - return connection.onClose(); - } - - @Override - public double availability() { - return connection.availability(); - } - - private void count(Frame frame) { - switch (frame.getType()) { - case CANCEL: - cancelCounter.increment(); - break; - case REQUEST_CHANNEL: - requestChannelCounter.increment(); - break; - case COMPLETE: - completeCounter.increment(); - break; - case ERROR: - errorCounter.increment(); - break; - case EXT: - extCounter.increment(); - break; - case FIRE_AND_FORGET: - fireAndForgetCounter.increment(); - break; - case KEEPALIVE: - keepAliveCounter.increment(); - break; - case LEASE: - leaseCounter.increment(); - break; - case METADATA_PUSH: - metadataPushCounter.increment(); - break; - case NEXT: - nextCounter.increment(); - break; - case NEXT_COMPLETE: - nextCompleteCounter.increment(); - break; - case PAYLOAD: - payloadCounter.increment(); - break; - case REQUEST_N: - requestNCounter.increment(); - break; - case REQUEST_RESPONSE: - requestResponseCounter.increment(); - break; - case REQUEST_STREAM: - requestStreamCounter.increment(); - break; - case RESUME: - resumeCounter.increment(); - break; - case RESUME_OK: - resumeOkCounter.increment(); - break; - case SETUP: - setupCounter.increment(); - break; - case UNDEFINED: - default: - undefinedCounter.increment(); - break; - } - } - }; - } -} diff --git a/rsocket-spectator/src/main/java/io/rsocket/spectator/SpectatorRSocket.java b/rsocket-spectator/src/main/java/io/rsocket/spectator/SpectatorRSocket.java deleted file mode 100644 index 79da95698..000000000 --- a/rsocket-spectator/src/main/java/io/rsocket/spectator/SpectatorRSocket.java +++ /dev/null @@ -1,239 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.spectator; - -import com.netflix.spectator.api.Counter; -import com.netflix.spectator.api.Registry; -import com.netflix.spectator.api.Timer; -import io.rsocket.Payload; -import io.rsocket.RSocket; -import java.util.concurrent.TimeUnit; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -/** Wraps a {@link RSocket} with counters */ -public class SpectatorRSocket implements RSocket { - private final RSocket delegate; - - private Counter fireAndForgetErrors; - private Counter fireAndForgetCanceled; - private Counter fireAndForgetTotal; - private Timer fireAndForgetTimer; - - private Counter requestResponseErrors; - private Counter requestResponseCanceled; - private Counter requestResponseTotal; - private Timer requestResponseTimer; - - private Counter requestStreamErrors; - private Counter requestStreamCanceled; - private Counter requestStreamTotal; - - private Counter requestChannelErrors; - private Counter requestChannelCanceled; - private Counter requestChannelTotal; - - private Counter metadataPushErrors; - private Counter metadataPushCanceled; - private Counter metadataPushTotal; - private Timer metadataPushTimer; - - public SpectatorRSocket(Registry registry, RSocket delegate, String... tags) { - this.delegate = delegate; - - this.fireAndForgetErrors = - registry.counter("reactiveSocketStats", concatenate(tags, "fireAndForget", "errors")); - this.fireAndForgetCanceled = - registry.counter("reactiveSocketStats", concatenate(tags, "fireAndForget", "canceled")); - this.fireAndForgetTotal = - registry.counter("reactiveSocketStats", concatenate(tags, "fireAndForget", "total")); - this.fireAndForgetTimer = - registry.timer("reactiveSocketStats", concatenate(tags, "fireAndForget", "timer")); - - this.requestResponseErrors = - registry.counter("reactiveSocketStats", concatenate(tags, "requestResponse", "errors")); - this.requestResponseCanceled = - registry.counter("reactiveSocketStats", concatenate(tags, "requestResponse", "canceled")); - this.requestResponseTotal = - registry.counter("reactiveSocketStats", concatenate(tags, "requestResponse", "total")); - this.requestResponseTimer = - registry.timer("reactiveSocketStats", concatenate(tags, "requestResponse", "timer")); - - this.requestStreamErrors = - registry.counter("reactiveSocketStats", concatenate(tags, "requestStream", "errors")); - this.requestStreamCanceled = - registry.counter("reactiveSocketStats", concatenate(tags, "requestStream", "canceled")); - this.requestStreamTotal = - registry.counter("reactiveSocketStats", concatenate(tags, "requestStream", "total")); - - this.requestChannelErrors = - registry.counter("reactiveSocketStats", concatenate(tags, "requestChannel", "errors")); - this.requestChannelCanceled = - registry.counter("reactiveSocketStats", concatenate(tags, "requestChannel", "canceled")); - this.requestChannelTotal = - registry.counter("reactiveSocketStats", concatenate(tags, "requestChannel", "total")); - - this.metadataPushErrors = - registry.counter("reactiveSocketStats", concatenate(tags, "metadataPush", "errors")); - this.metadataPushCanceled = - registry.counter("reactiveSocketStats", concatenate(tags, "metadataPush", "canceled")); - this.metadataPushTotal = - registry.counter("reactiveSocketStats", concatenate(tags, "metadataPush", "total")); - this.metadataPushTimer = - registry.timer("reactiveSocketStats", concatenate(tags, "metadataPush", "timer")); - } - - private static String[] concatenate(String[] a, String... b) { - if (a == null || a.length == 0) { - return b; - } - - int aLen = a.length; - int bLen = b.length; - - String[] c = new String[aLen + bLen]; - System.arraycopy(a, 0, c, 0, aLen); - System.arraycopy(b, 0, c, aLen, bLen); - - return c; - } - - @Override - public Mono fireAndForget(Payload payload) { - return Mono.defer( - () -> { - long start = System.nanoTime(); - return delegate - .fireAndForget(payload) - .doFinally( - signalType -> { - fireAndForgetTimer.record(System.nanoTime() - start, TimeUnit.NANOSECONDS); - fireAndForgetTotal.increment(); - - switch (signalType) { - case CANCEL: - fireAndForgetCanceled.increment(); - break; - case ON_ERROR: - fireAndForgetErrors.increment(); - } - }); - }); - } - - @Override - public Mono requestResponse(Payload payload) { - return Mono.defer( - () -> { - long start = System.nanoTime(); - return delegate - .requestResponse(payload) - .doFinally( - signalType -> { - requestResponseTimer.record(System.nanoTime() - start, TimeUnit.NANOSECONDS); - requestResponseTotal.increment(); - - switch (signalType) { - case CANCEL: - requestResponseCanceled.increment(); - break; - case ON_ERROR: - requestResponseErrors.increment(); - } - }); - }); - } - - @Override - public Flux requestStream(Payload payload) { - return Flux.defer( - () -> - delegate - .requestStream(payload) - .doFinally( - signalType -> { - requestStreamTotal.increment(); - - switch (signalType) { - case CANCEL: - requestStreamCanceled.increment(); - break; - case ON_ERROR: - requestStreamErrors.increment(); - } - })); - } - - @Override - public Flux requestChannel(Publisher payloads) { - return Flux.defer( - () -> - delegate - .requestChannel(payloads) - .doFinally( - signalType -> { - requestChannelTotal.increment(); - - switch (signalType) { - case CANCEL: - requestChannelCanceled.increment(); - break; - case ON_ERROR: - requestChannelErrors.increment(); - } - })); - } - - @Override - public Mono metadataPush(Payload payload) { - return Mono.defer( - () -> { - long start = System.nanoTime(); - return delegate - .metadataPush(payload) - .doFinally( - signalType -> { - metadataPushTimer.record(System.nanoTime() - start, TimeUnit.NANOSECONDS); - metadataPushTotal.increment(); - - switch (signalType) { - case CANCEL: - metadataPushCanceled.increment(); - break; - case ON_ERROR: - metadataPushErrors.increment(); - } - }); - }); - } - - @Override - public Mono close() { - return delegate.close(); - } - - @Override - public Mono onClose() { - return delegate.onClose(); - } - - @Override - public double availability() { - return delegate.availability(); - } -} diff --git a/rsocket-spectator/src/main/java/io/rsocket/spectator/SpectatorRSocketInterceptor.java b/rsocket-spectator/src/main/java/io/rsocket/spectator/SpectatorRSocketInterceptor.java deleted file mode 100644 index ceede8338..000000000 --- a/rsocket-spectator/src/main/java/io/rsocket/spectator/SpectatorRSocketInterceptor.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.spectator; - -import com.netflix.spectator.api.Registry; -import io.rsocket.RSocket; -import io.rsocket.plugins.RSocketInterceptor; - -/** Interceptor that wraps a {@link RSocket} with a {@link SpectatorRSocket} */ -public class SpectatorRSocketInterceptor implements RSocketInterceptor { - private static final String[] EMPTY = new String[0]; - private final Registry registry; - private final String[] tags; - - public SpectatorRSocketInterceptor(Registry registry, String... tags) { - this.registry = registry; - this.tags = tags; - } - - public SpectatorRSocketInterceptor(Registry registry) { - this(registry, EMPTY); - } - - @Override - public RSocket apply(RSocket reactiveSocket) { - return new SpectatorRSocket(registry, reactiveSocket, tags); - } -} diff --git a/rsocket-test/build.gradle b/rsocket-test/build.gradle index 10705bd19..bcdf88f28 100644 --- a/rsocket-test/build.gradle +++ b/rsocket-test/build.gradle @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -14,11 +14,28 @@ * limitations under the License. */ +plugins { + id 'java-library' + id 'maven-publish' + id 'signing' +} + dependencies { - compile project(':rsocket-core') - compile "junit:junit:4.12" - compile "org.mockito:mockito-core:2.10.0" - compile "org.hamcrest:hamcrest-library:1.3" - compile "org.hdrhistogram:HdrHistogram:2.1.9" - compile "io.projectreactor:reactor-test:3.1.0.RELEASE" + api project(':rsocket-core') + api 'org.hdrhistogram:HdrHistogram' + api 'org.junit.jupiter:junit-jupiter-api' + + implementation 'io.projectreactor:reactor-test' + implementation 'org.assertj:assertj-core' + implementation 'org.mockito:mockito-core' + implementation 'org.awaitility:awaitility' + implementation 'org.slf4j:slf4j-api' } + +jar { + manifest { + attributes("Automatic-Module-Name": "rsocket.test") + } +} + +description = 'Test utilities for RSocket projects' diff --git a/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java b/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java index 5c8e64264..e773b4a0d 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java +++ b/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java @@ -1,36 +1,50 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package io.rsocket.test; -import static org.junit.Assert.assertEquals; +import static org.assertj.core.api.Assertions.assertThat; import io.rsocket.Payload; -import io.rsocket.util.PayloadImpl; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; +import io.rsocket.util.DefaultPayload; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import reactor.core.publisher.Flux; public abstract class BaseClientServerTest> { - @Rule public final T setup = createClientServer(); + public final T setup = createClientServer(); protected abstract T createClientServer(); - @Test(timeout = 10000) + @BeforeEach + public void init() { + setup.init(); + } + + @AfterEach + public void teardown() { + setup.tearDown(); + } + + @Test + @Timeout(10000) public void testFireNForget10() { long outputCount = Flux.range(1, 10) @@ -39,22 +53,23 @@ public void testFireNForget10() { .count() .block(); - assertEquals(0, outputCount); + assertThat(outputCount).isZero(); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testPushMetadata10() { long outputCount = Flux.range(1, 10) - .flatMap(i -> setup.getRSocket().metadataPush(new PayloadImpl("", "metadata"))) + .flatMap(i -> setup.getRSocket().metadataPush(DefaultPayload.create("", "metadata"))) .doOnError(Throwable::printStackTrace) .count() .block(); - assertEquals(0, outputCount); + assertThat(outputCount).isZero(); } - @Test(timeout = 10000) + @Test // (timeout = 10000) public void testRequestResponse1() { long outputCount = Flux.range(1, 1) @@ -64,10 +79,11 @@ public void testRequestResponse1() { .count() .block(); - assertEquals(1, outputCount); + assertThat(outputCount).isZero(); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testRequestResponse10() { long outputCount = Flux.range(1, 10) @@ -77,7 +93,7 @@ public void testRequestResponse10() { .count() .block(); - assertEquals(10, outputCount); + assertThat(outputCount).isEqualTo(10); } private Payload testPayload(int metadataPresent) { @@ -93,10 +109,11 @@ private Payload testPayload(int metadataPresent) { metadata = "metadata"; break; } - return new PayloadImpl("hello", metadata); + return DefaultPayload.create("hello", metadata); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testRequestResponse100() { long outputCount = Flux.range(1, 100) @@ -106,10 +123,11 @@ public void testRequestResponse100() { .count() .block(); - assertEquals(100, outputCount); + assertThat(outputCount).isEqualTo(100); } - @Test(timeout = 20000) + @Test + @Timeout(20000) public void testRequestResponse10_000() { long outputCount = Flux.range(1, 10_000) @@ -119,19 +137,31 @@ public void testRequestResponse10_000() { .count() .block(); - assertEquals(10_000, outputCount); + assertThat(outputCount).isEqualTo(10_000); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testRequestStream() { Flux publisher = setup.getRSocket().requestStream(testPayload(3)); long count = publisher.take(5).count().block(); - assertEquals(5, count); + assertThat(count).isEqualTo(5); } - @Test(timeout = 10000) + @Test + @Timeout(10000) + public void testRequestStreamAll() { + Flux publisher = setup.getRSocket().requestStream(testPayload(3)); + + long count = publisher.count().block(); + + assertThat(count).isEqualTo(10000); + } + + @Test + @Timeout(10000) public void testRequestStreamWithRequestN() { CountdownBaseSubscriber ts = new CountdownBaseSubscriber(); ts.expect(5); @@ -139,16 +169,17 @@ public void testRequestStreamWithRequestN() { setup.getRSocket().requestStream(testPayload(3)).subscribe(ts); ts.await(); - assertEquals(5, ts.count()); + assertThat(ts.count()).isEqualTo(5); ts.expect(5); ts.await(); ts.cancel(); - assertEquals(10, ts.count()); + assertThat(ts.count()).isEqualTo(10); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testRequestStreamWithDelayedRequestN() { CountdownBaseSubscriber ts = new CountdownBaseSubscriber(); @@ -157,35 +188,37 @@ public void testRequestStreamWithDelayedRequestN() { ts.expect(5); ts.await(); - assertEquals(5, ts.count()); + assertThat(ts.count()).isEqualTo(5); ts.expect(5); ts.await(); ts.cancel(); - assertEquals(10, ts.count()); + assertThat(ts.count()).isEqualTo(10); } - @Test(timeout = 10000) - @Ignore + @Test + @Timeout(10000) public void testChannel0() { Flux publisher = setup.getRSocket().requestChannel(Flux.empty()); long count = publisher.count().block(); - assertEquals(0, count); + assertThat(count).isZero(); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testChannel1() { Flux publisher = setup.getRSocket().requestChannel(Flux.just(testPayload(0))); long count = publisher.count().block(); - assertEquals(1, count); + assertThat(count).isOne(); } - @Test(timeout = 10000) + @Test + @Timeout(10000) public void testChannel3() { Flux publisher = setup @@ -194,6 +227,48 @@ public void testChannel3() { long count = publisher.count().block(); - assertEquals(3, count); + assertThat(count).isEqualTo(3); + } + + @Test + @Timeout(10000) + public void testChannel512() { + Flux payloads = Flux.range(1, 512).map(i -> DefaultPayload.create("hello " + i)); + + long count = setup.getRSocket().requestChannel(payloads).count().block(); + + assertThat(count).isEqualTo(512); + } + + @Test + @Timeout(30000) + public void testChannel20_000() { + Flux payloads = Flux.range(1, 20_000).map(i -> DefaultPayload.create("hello " + i)); + + long count = setup.getRSocket().requestChannel(payloads).count().block(); + + assertThat(count).isEqualTo(20_000); + } + + @Test + @Timeout(60_000) + public void testChannel200_000() { + Flux payloads = Flux.range(1, 200_000).map(i -> DefaultPayload.create("hello " + i)); + + long count = setup.getRSocket().requestChannel(payloads).count().block(); + + assertThat(count).isEqualTo(200_000); + } + + @Test + @Timeout(60_000) + @Disabled + public void testChannel2_000_000() { + AtomicInteger counter = new AtomicInteger(0); + + Flux payloads = Flux.range(1, 2_000_000).map(i -> DefaultPayload.create("hello " + i)); + long count = setup.getRSocket().requestChannel(payloads).count().block(); + + assertThat(count).isEqualTo(2_000_000); } } diff --git a/rsocket-test/src/main/java/io/rsocket/test/ByteBufRepresentation.java b/rsocket-test/src/main/java/io/rsocket/test/ByteBufRepresentation.java new file mode 100644 index 000000000..d065f3d71 --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/ByteBufRepresentation.java @@ -0,0 +1,48 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.test; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.util.IllegalReferenceCountException; +import org.assertj.core.presentation.StandardRepresentation; + +public final class ByteBufRepresentation extends StandardRepresentation { + + @Override + protected String fallbackToStringOf(Object object) { + if (object instanceof ByteBuf) { + try { + String normalBufferString = object.toString(); + ByteBuf byteBuf = (ByteBuf) object; + if (byteBuf.readableBytes() <= 256) { + String prettyHexDump = ByteBufUtil.prettyHexDump(byteBuf); + return new StringBuilder() + .append(normalBufferString) + .append("\n") + .append(prettyHexDump) + .toString(); + } else { + return normalBufferString; + } + } catch (IllegalReferenceCountException e) { + // noops + } + } + + return super.fallbackToStringOf(object); + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java b/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java index 03c63b791..1d6b7f69e 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java +++ b/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -18,24 +18,25 @@ import io.rsocket.Closeable; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; -import org.junit.rules.ExternalResource; -import org.junit.runner.Description; -import org.junit.runners.model.Statement; import reactor.core.publisher.Mono; -public class ClientSetupRule extends ExternalResource { +public class ClientSetupRule { + private static final String data = "hello world"; + private static final String metadata = "metadata"; private Supplier addressSupplier; private BiFunction clientConnector; private Function serverInit; private RSocket client; + private S server; public ClientSetupRule( Supplier addressSupplier, @@ -45,36 +46,36 @@ public ClientSetupRule( this.serverInit = address -> - RSocketFactory.receive() - .acceptor((setup, sendingSocket) -> Mono.just(new TestRSocket())) - .transport(serverTransportSupplier.apply(address)) - .start() + RSocketServer.create((setup, rsocket) -> Mono.just(new TestRSocket(data, metadata))) + .bind(serverTransportSupplier.apply(address)) .block(); this.clientConnector = (address, server) -> - RSocketFactory.connect() - .transport(clientTransportSupplier.apply(address, server)) - .start() + RSocketConnector.connectWith(clientTransportSupplier.apply(address, server)) .doOnError(Throwable::printStackTrace) .block(); } - @Override - public Statement apply(Statement base, Description description) { - return new Statement() { - @Override - public void evaluate() throws Throwable { - T address = addressSupplier.get(); - S server = serverInit.apply(address); - client = clientConnector.apply(address, server); - base.evaluate(); - server.close().block(); - } - }; + public void init() { + T address = addressSupplier.get(); + S server = serverInit.apply(address); + client = clientConnector.apply(address, server); + } + + public void tearDown() { + server.dispose(); } public RSocket getRSocket() { return client; } + + public String expectedPayloadData() { + return data; + } + + public String expectedPayloadMetadata() { + return metadata; + } } diff --git a/rsocket-test/src/main/java/io/rsocket/test/CountdownBaseSubscriber.java b/rsocket-test/src/main/java/io/rsocket/test/CountdownBaseSubscriber.java index 1f7d79ba5..8fb948e9f 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/CountdownBaseSubscriber.java +++ b/rsocket-test/src/main/java/io/rsocket/test/CountdownBaseSubscriber.java @@ -1,3 +1,19 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.rsocket.test; import io.rsocket.Payload; diff --git a/rsocket-test/src/main/java/io/rsocket/test/LeaksTrackingByteBufAllocator.java b/rsocket-test/src/main/java/io/rsocket/test/LeaksTrackingByteBufAllocator.java new file mode 100644 index 000000000..46e807b09 --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/LeaksTrackingByteBufAllocator.java @@ -0,0 +1,294 @@ +package io.rsocket.test; + +import static java.util.concurrent.locks.LockSupport.parkNanos; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ResourceLeakDetector; +import java.lang.reflect.Field; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; +import org.assertj.core.api.Assertions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Additional Utils which allows to decorate a ByteBufAllocator and track/assertOnLeaks all created + * ByteBuffs + */ +public class LeaksTrackingByteBufAllocator implements ByteBufAllocator { + static final Logger LOGGER = LoggerFactory.getLogger(LeaksTrackingByteBufAllocator.class); + + /** + * Allows to instrument any given the instance of ByteBufAllocator + * + * @param allocator + * @return + */ + public static LeaksTrackingByteBufAllocator instrument(ByteBufAllocator allocator) { + return new LeaksTrackingByteBufAllocator(allocator, Duration.ZERO, ""); + } + + /** + * Allows to instrument any given the instance of ByteBufAllocator + * + * @param allocator + * @return + */ + public static LeaksTrackingByteBufAllocator instrument( + ByteBufAllocator allocator, Duration awaitZeroRefCntDuration, String tag) { + return new LeaksTrackingByteBufAllocator(allocator, awaitZeroRefCntDuration, tag); + } + + final ConcurrentLinkedQueue tracker = new ConcurrentLinkedQueue<>(); + + final ByteBufAllocator delegate; + + final Duration awaitZeroRefCntDuration; + + final String tag; + + private LeaksTrackingByteBufAllocator( + ByteBufAllocator delegate, Duration awaitZeroRefCntDuration, String tag) { + this.delegate = delegate; + this.awaitZeroRefCntDuration = awaitZeroRefCntDuration; + this.tag = tag; + } + + public LeaksTrackingByteBufAllocator assertHasNoLeaks() { + try { + ArrayList unreleased = new ArrayList<>(); + for (ByteBuf bb : tracker) { + if (bb.refCnt() != 0) { + unreleased.add(bb); + } + } + + final Duration awaitZeroRefCntDuration = this.awaitZeroRefCntDuration; + if (!unreleased.isEmpty() && !awaitZeroRefCntDuration.isZero()) { + final long startTime = System.currentTimeMillis(); + final long endTimeInMillis = startTime + awaitZeroRefCntDuration.toMillis(); + boolean hasUnreleased; + while (System.currentTimeMillis() <= endTimeInMillis) { + hasUnreleased = false; + for (ByteBuf bb : unreleased) { + if (bb.refCnt() != 0) { + hasUnreleased = true; + break; + } + } + + if (!hasUnreleased) { + return this; + } + + LOGGER.debug(tag + " await buffers to be released"); + for (int i = 0; i < 100; i++) { + System.gc(); + parkNanos(1000); + System.gc(); + } + } + } + + Set collected = new HashSet<>(); + for (ByteBuf buf : unreleased) { + if (buf.refCnt() != 0) { + try { + collected.add(buf); + } catch (IllegalReferenceCountException ignored) { + // fine to ignore if throws because of refCnt + } + } + } + + Assertions.assertThat( + collected + .stream() + .filter(bb -> bb.refCnt() != 0) + .peek( + bb -> { + try { + LOGGER.debug(tag + " " + resolveTrackingInfo(bb)); + } catch (Exception e) { + e.printStackTrace(); + } + })) + .describedAs("[" + tag + "] all buffers expected to be released but got ") + .isEmpty(); + } finally { + tracker.clear(); + } + return this; + } + + // Delegating logic with tracking of buffers + + @Override + public ByteBuf buffer() { + return track(delegate.buffer()); + } + + @Override + public ByteBuf buffer(int initialCapacity) { + return track(delegate.buffer(initialCapacity)); + } + + @Override + public ByteBuf buffer(int initialCapacity, int maxCapacity) { + return track(delegate.buffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf ioBuffer() { + return track(delegate.ioBuffer()); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity) { + return track(delegate.ioBuffer(initialCapacity)); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.ioBuffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf heapBuffer() { + return track(delegate.heapBuffer()); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity) { + return track(delegate.heapBuffer(initialCapacity)); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.heapBuffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf directBuffer() { + return track(delegate.directBuffer()); + } + + @Override + public ByteBuf directBuffer(int initialCapacity) { + return track(delegate.directBuffer(initialCapacity)); + } + + @Override + public ByteBuf directBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.directBuffer(initialCapacity, maxCapacity)); + } + + @Override + public CompositeByteBuf compositeBuffer() { + return track(delegate.compositeBuffer()); + } + + @Override + public CompositeByteBuf compositeBuffer(int maxNumComponents) { + return track(delegate.compositeBuffer(maxNumComponents)); + } + + @Override + public CompositeByteBuf compositeHeapBuffer() { + return track(delegate.compositeHeapBuffer()); + } + + @Override + public CompositeByteBuf compositeHeapBuffer(int maxNumComponents) { + return track(delegate.compositeHeapBuffer(maxNumComponents)); + } + + @Override + public CompositeByteBuf compositeDirectBuffer() { + return track(delegate.compositeDirectBuffer()); + } + + @Override + public CompositeByteBuf compositeDirectBuffer(int maxNumComponents) { + return track(delegate.compositeDirectBuffer(maxNumComponents)); + } + + @Override + public boolean isDirectBufferPooled() { + return delegate.isDirectBufferPooled(); + } + + @Override + public int calculateNewCapacity(int minNewCapacity, int maxCapacity) { + return delegate.calculateNewCapacity(minNewCapacity, maxCapacity); + } + + T track(T buffer) { + tracker.offer(buffer); + + return buffer; + } + + static final Class simpleLeakAwareCompositeByteBufClass; + static final Field leakFieldForComposite; + static final Class simpleLeakAwareByteBufClass; + static final Field leakFieldForNormal; + static final Field allLeaksField; + + static { + try { + { + final Class aClass = Class.forName("io.netty.buffer.SimpleLeakAwareCompositeByteBuf"); + final Field leakField = aClass.getDeclaredField("leak"); + + leakField.setAccessible(true); + + simpleLeakAwareCompositeByteBufClass = aClass; + leakFieldForComposite = leakField; + } + + { + final Class aClass = Class.forName("io.netty.buffer.SimpleLeakAwareByteBuf"); + final Field leakField = aClass.getDeclaredField("leak"); + + leakField.setAccessible(true); + + simpleLeakAwareByteBufClass = aClass; + leakFieldForNormal = leakField; + } + + { + final Class aClass = + Class.forName("io.netty.util.ResourceLeakDetector$DefaultResourceLeak"); + final Field field = aClass.getDeclaredField("allLeaks"); + + field.setAccessible(true); + + allLeaksField = field; + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @SuppressWarnings("unchecked") + static Set resolveTrackingInfo(ByteBuf byteBuf) throws Exception { + if (ResourceLeakDetector.getLevel().ordinal() + >= ResourceLeakDetector.Level.ADVANCED.ordinal()) { + if (simpleLeakAwareCompositeByteBufClass.isInstance(byteBuf)) { + return (Set) allLeaksField.get(leakFieldForComposite.get(byteBuf)); + } else if (simpleLeakAwareByteBufClass.isInstance(byteBuf)) { + return (Set) allLeaksField.get(leakFieldForNormal.get(byteBuf)); + } + } + + return Collections.emptySet(); + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/PerfTest.java b/rsocket-test/src/main/java/io/rsocket/test/PerfTest.java new file mode 100644 index 000000000..3830ec1bc --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/PerfTest.java @@ -0,0 +1,17 @@ +package io.rsocket.test; + +import java.lang.annotation.*; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +/** + * {@code @PerfTest} is used to signal that the annotated test class or method is performance test, + * and is disabled unless enabled via setting the {@code TEST_PERF_ENABLED} environment variable to + * {@code true}. + */ +@Target({ElementType.TYPE, ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@EnabledIfEnvironmentVariable(named = "TEST_PERF_ENABLED", matches = "(?i)true") +@Test +public @interface PerfTest {} diff --git a/rsocket-test/src/main/java/io/rsocket/test/PingClient.java b/rsocket-test/src/main/java/io/rsocket/test/PingClient.java index a2b8ec334..14740950a 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/PingClient.java +++ b/rsocket-test/src/main/java/io/rsocket/test/PingClient.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -18,9 +18,11 @@ import io.rsocket.Payload; import io.rsocket.RSocket; -import io.rsocket.util.PayloadImpl; +import io.rsocket.util.ByteBufPayload; import java.time.Duration; +import java.util.function.BiFunction; import org.HdrHistogram.Recorder; +import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -31,7 +33,7 @@ public class PingClient { public PingClient(Mono client) { this.client = client; - this.payload = new PayloadImpl("hello"); + this.payload = ByteBufPayload.create("hello"); } public Recorder startTracker(Duration interval) { @@ -49,23 +51,38 @@ public Recorder startTracker(Duration interval) { return histogram; } - public Flux startPingPong(int count, final Recorder histogram) { - return client - .flatMapMany( + public Flux requestResponsePingPong(int count, final Recorder histogram) { + return pingPong(RSocket::requestResponse, count, histogram); + } + + public Flux requestStreamPingPong(int count, final Recorder histogram) { + return pingPong(RSocket::requestStream, count, histogram); + } + + Flux pingPong( + BiFunction> interaction, + int count, + final Recorder histogram) { + return Flux.usingWhen( + client, rsocket -> Flux.range(1, count) .flatMap( i -> { long start = System.nanoTime(); - return rsocket - .requestResponse(payload) + return Flux.from(interaction.apply(rsocket, payload.retain())) + .doOnNext(Payload::release) .doFinally( signalType -> { long diff = System.nanoTime() - start; histogram.recordValue(diff); }); }, - 16)) + 64), + rsocket -> { + rsocket.dispose(); + return rsocket.onClose(); + }) .doOnError(Throwable::printStackTrace); } } diff --git a/rsocket-test/src/main/java/io/rsocket/test/PingHandler.java b/rsocket-test/src/main/java/io/rsocket/test/PingHandler.java index f349a978c..47f40a59d 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/PingHandler.java +++ b/rsocket-test/src/main/java/io/rsocket/test/PingHandler.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,13 +16,13 @@ package io.rsocket.test; -import io.rsocket.AbstractRSocket; import io.rsocket.ConnectionSetupPayload; import io.rsocket.Payload; import io.rsocket.RSocket; import io.rsocket.SocketAcceptor; -import io.rsocket.util.PayloadImpl; +import io.rsocket.util.ByteBufPayload; import java.util.concurrent.ThreadLocalRandom; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; public class PingHandler implements SocketAcceptor { @@ -32,20 +32,27 @@ public class PingHandler implements SocketAcceptor { public PingHandler() { byte[] data = new byte[1024]; ThreadLocalRandom.current().nextBytes(data); - pong = new PayloadImpl(data); + pong = ByteBufPayload.create(data); } public PingHandler(byte[] data) { - pong = new PayloadImpl(data); + pong = ByteBufPayload.create(data); } @Override public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) { return Mono.just( - new AbstractRSocket() { + new RSocket() { @Override public Mono requestResponse(Payload payload) { - return Mono.just(pong); + payload.release(); + return Mono.just(pong.retain()); + } + + @Override + public Flux requestStream(Payload payload) { + payload.release(); + return Flux.range(0, 100).map(v -> pong.retain()); } }); } diff --git a/rsocket-test/src/main/java/io/rsocket/test/SlowTest.java b/rsocket-test/src/main/java/io/rsocket/test/SlowTest.java new file mode 100644 index 000000000..596cc0ffb --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/SlowTest.java @@ -0,0 +1,37 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.test; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +/** + * {@code @SlowTest} is used to signal that the annotated test class or test method is slow running + * and will be disabled unless enabled via setting the {@code TEST_SLOW_ENABLED} environment + * variable to {@code true}. + */ +@Target({ElementType.TYPE, ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@EnabledIfEnvironmentVariable(named = "TEST_SLOW_ENABLED", matches = "(?i)true") +@Test +public @interface SlowTest {} diff --git a/rsocket-test/src/main/java/io/rsocket/test/TestDuplexConnection.java b/rsocket-test/src/main/java/io/rsocket/test/TestDuplexConnection.java new file mode 100644 index 000000000..57a00e229 --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/TestDuplexConnection.java @@ -0,0 +1,166 @@ +package io.rsocket.test; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.PayloadFrameCodec; +import java.net.SocketAddress; +import java.util.function.BiFunction; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; + +public class TestDuplexConnection implements DuplexConnection { + + final ByteBufAllocator allocator; + final Sinks.Many inbound = Sinks.unsafe().many().unicast().onBackpressureError(); + final Sinks.Many outbound = Sinks.unsafe().many().unicast().onBackpressureError(); + final Sinks.One close = Sinks.one(); + + public TestDuplexConnection( + CoreSubscriber outboundSubscriber, boolean trackLeaks) { + this.outbound.asFlux().subscribe(outboundSubscriber); + this.allocator = + trackLeaks + ? LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT) + : ByteBufAllocator.DEFAULT; + } + + @Override + public void dispose() { + this.inbound.tryEmitComplete(); + this.outbound.tryEmitComplete(); + this.close.tryEmitEmpty(); + } + + @Override + public Mono onClose() { + return this.close.asMono(); + } + + @Override + public void sendErrorAndClose(RSocketErrorException errorException) {} + + @Override + public Flux receive() { + return this.inbound + .asFlux() + .transform( + Operators.lift( + (BiFunction< + Scannable, + CoreSubscriber, + CoreSubscriber>) + ByteBufReleaserOperator::create)); + } + + @Override + public ByteBufAllocator alloc() { + return this.allocator; + } + + @Override + public SocketAddress remoteAddress() { + return new SocketAddress() { + @Override + public String toString() { + return "Test"; + } + }; + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + this.outbound.tryEmitNext(frame); + } + + public void sendPayloadFrame( + int streamId, ByteBuf data, @Nullable ByteBuf metadata, boolean complete) { + sendFrame( + streamId, + PayloadFrameCodec.encode(this.allocator, streamId, false, complete, true, metadata, data)); + } + + static class ByteBufReleaserOperator + implements CoreSubscriber, Subscription, Fuseable.QueueSubscription { + + static CoreSubscriber create( + Scannable scannable, CoreSubscriber actual) { + return new ByteBufReleaserOperator(actual); + } + + final CoreSubscriber actual; + + Subscription s; + + public ByteBufReleaserOperator(CoreSubscriber actual) { + this.actual = actual; + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + this.actual.onSubscribe(this); + } + } + + @Override + public void onNext(ByteBuf buf) { + this.actual.onNext(buf); + buf.release(); + } + + @Override + public void onError(Throwable t) { + actual.onError(t); + } + + @Override + public void onComplete() { + actual.onComplete(); + } + + @Override + public void request(long n) { + s.request(n); + } + + @Override + public void cancel() { + s.cancel(); + } + + @Override + public int requestFusion(int requestedMode) { + return Fuseable.NONE; + } + + @Override + public ByteBuf poll() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public int size() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public boolean isEmpty() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/TestFrames.java b/rsocket-test/src/main/java/io/rsocket/test/TestFrames.java new file mode 100644 index 000000000..1e66abc5e --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/TestFrames.java @@ -0,0 +1,108 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.test; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; +import io.rsocket.frame.*; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; + +/** Test instances of all frame types. */ +public final class TestFrames { + private static final ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + private static final Payload emptyPayload = DefaultPayload.create(Unpooled.EMPTY_BUFFER); + + private TestFrames() {} + + /** @return {@link ByteBuf} representing test instance of Cancel frame */ + public static ByteBuf createTestCancelFrame() { + return CancelFrameCodec.encode(allocator, 1); + } + + /** @return {@link ByteBuf} representing test instance of Error frame */ + public static ByteBuf createTestErrorFrame() { + return ErrorFrameCodec.encode(allocator, 1, new RuntimeException()); + } + + /** @return {@link ByteBuf} representing test instance of Extension frame */ + public static ByteBuf createTestExtensionFrame() { + return ExtensionFrameCodec.encode( + allocator, 1, 1, Unpooled.EMPTY_BUFFER, Unpooled.EMPTY_BUFFER); + } + + /** @return {@link ByteBuf} representing test instance of Keep-Alive frame */ + public static ByteBuf createTestKeepaliveFrame() { + return KeepAliveFrameCodec.encode(allocator, false, 1, Unpooled.EMPTY_BUFFER); + } + + /** @return {@link ByteBuf} representing test instance of Lease frame */ + public static ByteBuf createTestLeaseFrame() { + return LeaseFrameCodec.encode(allocator, 1, 1, null); + } + + /** @return {@link ByteBuf} representing test instance of Metadata-Push frame */ + public static ByteBuf createTestMetadataPushFrame() { + return MetadataPushFrameCodec.encode(allocator, Unpooled.EMPTY_BUFFER); + } + + /** @return {@link ByteBuf} representing test instance of Payload frame */ + public static ByteBuf createTestPayloadFrame() { + return PayloadFrameCodec.encode(allocator, 1, false, true, false, null, Unpooled.EMPTY_BUFFER); + } + + /** @return {@link ByteBuf} representing test instance of Request-Channel frame */ + public static ByteBuf createTestRequestChannelFrame() { + return RequestChannelFrameCodec.encode( + allocator, 1, false, false, 1, null, Unpooled.EMPTY_BUFFER); + } + + /** @return {@link ByteBuf} representing test instance of Fire-and-Forget frame */ + public static ByteBuf createTestRequestFireAndForgetFrame() { + return RequestFireAndForgetFrameCodec.encode(allocator, 1, false, null, Unpooled.EMPTY_BUFFER); + } + + /** @return {@link ByteBuf} representing test instance of Request-N frame */ + public static ByteBuf createTestRequestNFrame() { + return RequestNFrameCodec.encode(allocator, 1, 1); + } + + /** @return {@link ByteBuf} representing test instance of Request-Response frame */ + public static ByteBuf createTestRequestResponseFrame() { + return RequestResponseFrameCodec.encodeReleasingPayload(allocator, 1, emptyPayload); + } + + /** @return {@link ByteBuf} representing test instance of Request-Stream frame */ + public static ByteBuf createTestRequestStreamFrame() { + return RequestStreamFrameCodec.encodeReleasingPayload(allocator, 1, 1L, emptyPayload); + } + + /** @return {@link ByteBuf} representing test instance of Setup frame */ + public static ByteBuf createTestSetupFrame() { + return SetupFrameCodec.encode( + allocator, + false, + 1, + 1, + Unpooled.EMPTY_BUFFER, + "metadataType", + "dataType", + EmptyPayload.INSTANCE); + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java b/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java index 5f75400e8..1b294e394 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java +++ b/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,38 +16,94 @@ package io.rsocket.test; -import io.rsocket.AbstractRSocket; +import static java.util.concurrent.locks.LockSupport.parkNanos; + import io.rsocket.Payload; -import io.rsocket.util.PayloadImpl; +import io.rsocket.RSocket; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicLong; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -public class TestRSocket extends AbstractRSocket { +public class TestRSocket implements RSocket { + private final String data; + private final String metadata; + + private final AtomicLong observedInteractions = new AtomicLong(); + private final AtomicLong activeInteractions = new AtomicLong(); + + public TestRSocket(String data, String metadata) { + this.data = data; + this.metadata = metadata; + } @Override public Mono requestResponse(Payload payload) { - return Mono.just(new PayloadImpl("hello world", "metadata")); + activeInteractions.getAndIncrement(); + payload.release(); + observedInteractions.getAndIncrement(); + return Mono.just(ByteBufPayload.create(data, metadata)) + .doFinally(__ -> activeInteractions.getAndDecrement()); } @Override public Flux requestStream(Payload payload) { - return Flux.range(1, 10_000).flatMap(l -> requestResponse(payload)); + activeInteractions.getAndIncrement(); + payload.release(); + observedInteractions.getAndIncrement(); + return Flux.range(1, 10_000) + .map(l -> ByteBufPayload.create(data, metadata)) + .doFinally(__ -> activeInteractions.getAndDecrement()); } @Override public Mono metadataPush(Payload payload) { - return Mono.empty(); + activeInteractions.getAndIncrement(); + payload.release(); + observedInteractions.getAndIncrement(); + return Mono.empty().doFinally(__ -> activeInteractions.getAndDecrement()); } @Override public Mono fireAndForget(Payload payload) { - return Mono.empty(); + activeInteractions.getAndIncrement(); + payload.release(); + observedInteractions.getAndIncrement(); + return Mono.empty().doFinally(__ -> activeInteractions.getAndDecrement()); } @Override public Flux requestChannel(Publisher payloads) { - // TODO is defensive copy neccesary? - return Flux.from(payloads).map(p -> new PayloadImpl(p.getDataUtf8(), p.getMetadataUtf8())); + activeInteractions.getAndIncrement(); + observedInteractions.getAndIncrement(); + return Flux.from(payloads).doFinally(__ -> activeInteractions.getAndDecrement()); + } + + public boolean awaitAllInteractionTermination(Duration duration) { + long end = duration.plusNanos(System.nanoTime()).toNanos(); + long activeNow; + while ((activeNow = activeInteractions.get()) > 0) { + if (System.nanoTime() >= end) { + return false; + } + parkNanos(100); + } + + return activeNow == 0; + } + + public boolean awaitUntilObserved(int interactions, Duration duration) { + long end = System.nanoTime() + duration.toNanos(); + long observed; + while ((observed = observedInteractions.get()) < interactions) { + if (System.nanoTime() >= end) { + return false; + } + parkNanos(100); + } + + return observed >= interactions; } } diff --git a/rsocket-test/src/main/java/io/rsocket/test/TestSubscriber.java b/rsocket-test/src/main/java/io/rsocket/test/TestSubscriber.java index 973440df3..62b6c242b 100644 --- a/rsocket-test/src/main/java/io/rsocket/test/TestSubscriber.java +++ b/rsocket-test/src/main/java/io/rsocket/test/TestSubscriber.java @@ -1,3 +1,19 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.rsocket.test; import static org.mockito.ArgumentMatchers.any; diff --git a/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java b/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java new file mode 100644 index 000000000..1fcca97db --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java @@ -0,0 +1,984 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.test; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.netty.util.ResourceLeakDetector; +import io.rsocket.Closeable; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.RSocketErrorException; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.core.Resume; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.DuplexConnectionInterceptor; +import io.rsocket.resume.InMemoryResumableFramesStore; +import io.rsocket.transport.ClientTransport; +import io.rsocket.transport.ServerTransport; +import io.rsocket.util.ByteBufPayload; +import java.io.BufferedReader; +import java.io.InputStreamReader; +import java.net.SocketAddress; +import java.time.Duration; +import java.util.Arrays; +import java.util.concurrent.CancellationException; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.BiFunction; +import java.util.function.Predicate; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.zip.GZIPInputStream; +import org.assertj.core.api.Assertions; +import org.assertj.core.api.Assumptions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.Exceptions; +import reactor.core.Fuseable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.util.Logger; +import reactor.util.Loggers; + +public interface TransportTest { + + Logger logger = Loggers.getLogger(TransportTest.class); + + String MOCK_DATA = "test-data"; + String MOCK_METADATA = "metadata"; + String LARGE_DATA = read("words.shakespeare.txt.gz"); + Payload LARGE_PAYLOAD = ByteBufPayload.create(LARGE_DATA, LARGE_DATA); + + static String read(String resourceName) { + try (BufferedReader br = + new BufferedReader( + new InputStreamReader( + new GZIPInputStream( + TransportTest.class.getClassLoader().getResourceAsStream(resourceName))))) { + + return br.lines().map(String::toLowerCase).collect(Collectors.joining("\n\r")); + } catch (Throwable e) { + throw new RuntimeException(e); + } + } + + @BeforeEach + default void setup() { + Hooks.onOperatorDebug(); + } + + @AfterEach + default void close() { + try { + logger.debug("------------------Awaiting communication to finish------------------"); + getTransportPair().responder.awaitAllInteractionTermination(getTimeout()); + logger.debug("---------------------Disposing Client And Server--------------------"); + getTransportPair().dispose(); + getTransportPair().awaitClosed(getTimeout()); + logger.debug("------------------------Disposing Schedulers-------------------------"); + Schedulers.parallel().disposeGracefully().timeout(getTimeout(), Mono.empty()).block(); + Schedulers.boundedElastic().disposeGracefully().timeout(getTimeout(), Mono.empty()).block(); + Schedulers.single().disposeGracefully().timeout(getTimeout(), Mono.empty()).block(); + logger.debug("---------------------------Leaks Checking----------------------------"); + RuntimeException throwable = + new RuntimeException() { + @Override + public synchronized Throwable fillInStackTrace() { + return this; + } + + @Override + public String getMessage() { + return Arrays.toString(getSuppressed()); + } + }; + + try { + getTransportPair().byteBufAllocator2.assertHasNoLeaks(); + } catch (Throwable t) { + throwable = Exceptions.addSuppressed(throwable, t); + } + + try { + getTransportPair().byteBufAllocator1.assertHasNoLeaks(); + } catch (Throwable t) { + throwable = Exceptions.addSuppressed(throwable, t); + } + + if (throwable.getSuppressed().length > 0) { + throw throwable; + } + } finally { + Hooks.resetOnOperatorDebug(); + Schedulers.resetOnHandleError(); + } + } + + default Payload createTestPayload(int metadataPresent) { + String metadata1; + + switch (metadataPresent % 5) { + case 0: + metadata1 = null; + break; + case 1: + metadata1 = ""; + break; + default: + metadata1 = MOCK_METADATA; + break; + } + String metadata = metadata1; + + return ByteBufPayload.create(MOCK_DATA, metadata); + } + + @DisplayName("makes 10 fireAndForget requests") + @Test + default void fireAndForget10() { + Flux.range(1, 10) + .flatMap(i -> getClient().fireAndForget(createTestPayload(i))) + .as(StepVerifier::create) + .expectComplete() + .verify(getTimeout()); + + getTransportPair().responder.awaitUntilObserved(10, getTimeout()); + } + + @DisplayName("makes 10 fireAndForget with Large Payload in Requests") + @Test + default void largePayloadFireAndForget10() { + Flux.range(1, 10) + .flatMap(i -> getClient().fireAndForget(LARGE_PAYLOAD.retain())) + .as(StepVerifier::create) + .expectComplete() + .verify(getTimeout()); + + getTransportPair().responder.awaitUntilObserved(10, getTimeout()); + } + + default RSocket getClient() { + return getTransportPair().getClient(); + } + + Duration getTimeout(); + + TransportPair getTransportPair(); + + @DisplayName("makes 10 metadataPush requests") + @Test + default void metadataPush10() { + Assumptions.assumeThat(getTransportPair().withResumability).isFalse(); + Flux.range(1, 10) + .flatMap(i -> getClient().metadataPush(ByteBufPayload.create("", "test-metadata"))) + .as(StepVerifier::create) + .expectComplete() + .verify(getTimeout()); + + getTransportPair().responder.awaitUntilObserved(10, getTimeout()); + } + + @DisplayName("makes 10 metadataPush with Large Metadata in requests") + @Test + default void largePayloadMetadataPush10() { + Assumptions.assumeThat(getTransportPair().withResumability).isFalse(); + Flux.range(1, 10) + .flatMap(i -> getClient().metadataPush(ByteBufPayload.create("", LARGE_DATA))) + .as(StepVerifier::create) + .expectComplete() + .verify(getTimeout()); + + getTransportPair().responder.awaitUntilObserved(10, getTimeout()); + } + + @DisplayName("makes 1 requestChannel request with 0 payloads") + @Test + default void requestChannel0() { + getClient() + .requestChannel(Flux.empty()) + .as(StepVerifier::create) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(CancellationException.class) + .hasMessage("Empty Source")) + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestChannel request with 1 payloads") + @Test + default void requestChannel1() { + getClient() + .requestChannel(Mono.just(createTestPayload(0))) + .doOnNext(Payload::release) + .as(StepVerifier::create) + .thenConsumeWhile(new PayloadPredicate(1)) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestChannel request with 200,000 payloads") + @Test + default void requestChannel200_000() { + Flux payloads = Flux.range(0, 200_000).map(this::createTestPayload); + + getClient() + .requestChannel(payloads) + .doOnNext(Payload::release) + .limitRate(8) + .as(StepVerifier::create) + .thenConsumeWhile(new PayloadPredicate(200_000)) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestChannel request with 50 large payloads") + @Test + default void largePayloadRequestChannel50() { + Flux payloads = Flux.range(0, 50).map(__ -> LARGE_PAYLOAD.retain()); + + getClient() + .requestChannel(payloads) + .doOnNext(Payload::release) + .as(StepVerifier::create) + .thenConsumeWhile(new PayloadPredicate(50)) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestChannel request with 20,000 payloads") + @Test + default void requestChannel20_000() { + Flux payloads = Flux.range(0, 20_000).map(metadataPresent -> createTestPayload(7)); + + getClient() + .requestChannel(payloads) + .doOnNext(this::assertChannelPayload) + .doOnNext(Payload::release) + .as(StepVerifier::create) + .thenConsumeWhile(new PayloadPredicate(20_000)) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestChannel request with 2,000,000 payloads") + @SlowTest + default void requestChannel2_000_000() { + Flux payloads = Flux.range(0, 2_000_000).map(this::createTestPayload); + + getClient() + .requestChannel(payloads) + .doOnNext(Payload::release) + .limitRate(8) + .as(StepVerifier::create) + .thenConsumeWhile(new PayloadPredicate(2_000_000)) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestChannel request with 3 payloads") + @Test + default void requestChannel3() { + AtomicLong requested = new AtomicLong(); + Flux payloads = + Flux.range(0, 3).doOnRequest(requested::addAndGet).map(this::createTestPayload); + + getClient() + .requestChannel(payloads) + .doOnNext(Payload::release) + .as(publisher -> StepVerifier.create(publisher, 3)) + .thenConsumeWhile(new PayloadPredicate(3)) + .expectComplete() + .verify(getTimeout()); + + Assertions.assertThat(requested.get()).isEqualTo(3L); + } + + @DisplayName("makes 1 requestChannel request with 256 payloads") + @Test + default void requestChannel256() { + AtomicInteger counter = new AtomicInteger(); + Flux payloads = + Flux.defer( + () -> { + final int subscription = counter.getAndIncrement(); + return Flux.range(0, 256) + .map(i -> "S{" + subscription + "}: Data{" + i + "}") + .map(data -> ByteBufPayload.create(data)); + }); + final Scheduler scheduler = Schedulers.fromExecutorService(Executors.newFixedThreadPool(12)); + + try { + Flux.range(0, 1024) + .flatMap(v -> Mono.fromRunnable(() -> check(payloads)).subscribeOn(scheduler), 12) + .blockLast(); + } finally { + scheduler.disposeGracefully().block(); + } + } + + default void check(Flux payloads) { + getClient() + .requestChannel(payloads) + .doOnNext(ReferenceCounted::release) + .limitRate(8) + .as(StepVerifier::create) + .thenConsumeWhile(new PayloadPredicate(256)) + .as("expected 256 items") + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestResponse request") + @Test + default void requestResponse1() { + getClient() + .requestResponse(createTestPayload(1)) + .doOnNext(this::assertPayload) + .doOnNext(Payload::release) + .as(StepVerifier::create) + .expectNextCount(1) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 10 requestResponse requests") + @Test + default void requestResponse10() { + Flux.range(1, 10) + .flatMap( + i -> + getClient() + .requestResponse(createTestPayload(i)) + .doOnNext(v -> assertPayload(v)) + .doOnNext(Payload::release)) + .as(StepVerifier::create) + .expectNextCount(10) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 100 requestResponse requests") + @Test + default void requestResponse100() { + Flux.range(1, 100) + .flatMap(i -> getClient().requestResponse(createTestPayload(i)).doOnNext(Payload::release)) + .as(StepVerifier::create) + .expectNextCount(100) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 50 requestResponse requests") + @Test + default void largePayloadRequestResponse50() { + Flux.range(1, 50) + .flatMap( + i -> getClient().requestResponse(LARGE_PAYLOAD.retain()).doOnNext(Payload::release)) + .as(StepVerifier::create) + .expectNextCount(50) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 10,000 requestResponse requests") + @Test + default void requestResponse10_000() { + Flux.range(1, 10_000) + .flatMap(i -> getClient().requestResponse(createTestPayload(i)).doOnNext(Payload::release)) + .as(StepVerifier::create) + .expectNextCount(10_000) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestStream request and receives 10,000 responses") + @Test + default void requestStream10_000() { + getClient() + .requestStream(createTestPayload(3)) + .doOnNext(this::assertPayload) + .doOnNext(Payload::release) + .as(StepVerifier::create) + .expectNextCount(10_000) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestStream request and receives 5 responses") + @Test + default void requestStream5() { + getClient() + .requestStream(createTestPayload(3)) + .doOnNext(this::assertPayload) + .doOnNext(Payload::release) + .take(5) + .as(StepVerifier::create) + .expectNextCount(5) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestStream request and consumes result incrementally") + @Test + default void requestStreamDelayedRequestN() { + getClient() + .requestStream(createTestPayload(3)) + .take(10) + .doOnNext(Payload::release) + .as(StepVerifier::create) + .thenRequest(5) + .expectNextCount(5) + .thenRequest(5) + .expectNextCount(5) + .expectComplete() + .verify(getTimeout()); + } + + default void assertPayload(Payload p) { + TransportPair transportPair = getTransportPair(); + if (!transportPair.expectedPayloadData().equals(p.getDataUtf8()) + || !transportPair.expectedPayloadMetadata().equals(p.getMetadataUtf8())) { + throw new IllegalStateException("Unexpected payload"); + } + } + + default void assertChannelPayload(Payload p) { + if (!MOCK_DATA.equals(p.getDataUtf8()) || !MOCK_METADATA.equals(p.getMetadataUtf8())) { + throw new IllegalStateException("Unexpected payload"); + } + } + + class TransportPair implements Disposable { + + private static final String data = "hello world"; + private static final String metadata = "metadata"; + + private final boolean withResumability; + private final boolean runClientWithAsyncInterceptors; + private final boolean runServerWithAsyncInterceptors; + + private final LeaksTrackingByteBufAllocator byteBufAllocator1 = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofMinutes(1), "Client"); + private final LeaksTrackingByteBufAllocator byteBufAllocator2 = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofMinutes(1), "Server"); + + private final TestRSocket responder; + + private final RSocket client; + + private final S server; + + public TransportPair( + Supplier addressSupplier, + TriFunction clientTransportSupplier, + BiFunction> serverTransportSupplier) { + this(addressSupplier, clientTransportSupplier, serverTransportSupplier, false); + } + + public TransportPair( + Supplier addressSupplier, + TriFunction clientTransportSupplier, + BiFunction> serverTransportSupplier, + boolean withRandomFragmentation) { + this( + addressSupplier, + clientTransportSupplier, + serverTransportSupplier, + withRandomFragmentation, + false); + } + + public TransportPair( + Supplier addressSupplier, + TriFunction clientTransportSupplier, + BiFunction> serverTransportSupplier, + boolean withRandomFragmentation, + boolean withResumability) { + Schedulers.onHandleError((t, e) -> e.printStackTrace()); + Schedulers.resetFactory(); + + this.withResumability = withResumability; + + T address = addressSupplier.get(); + + this.runClientWithAsyncInterceptors = ThreadLocalRandom.current().nextBoolean(); + this.runServerWithAsyncInterceptors = ThreadLocalRandom.current().nextBoolean(); + + ByteBufAllocator allocatorToSupply1; + ByteBufAllocator allocatorToSupply2; + if (ResourceLeakDetector.getLevel() == ResourceLeakDetector.Level.ADVANCED + || ResourceLeakDetector.getLevel() == ResourceLeakDetector.Level.PARANOID) { + logger.info("Using LeakTrackingByteBufAllocator"); + allocatorToSupply1 = byteBufAllocator1; + allocatorToSupply2 = byteBufAllocator2; + } else { + allocatorToSupply1 = ByteBufAllocator.DEFAULT; + allocatorToSupply2 = ByteBufAllocator.DEFAULT; + } + responder = new TestRSocket(TransportPair.data, metadata); + final RSocketServer rSocketServer = + RSocketServer.create((setup, sendingSocket) -> Mono.just(responder)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .interceptors( + registry -> { + if (runServerWithAsyncInterceptors && !withResumability) { + logger.info( + "Perform Integration Test with Async Interceptors Enabled For Server"); + registry + .forConnection( + (type, duplexConnection) -> + new AsyncDuplexConnection(duplexConnection, "server")) + .forSocketAcceptor( + delegate -> + (connectionSetupPayload, sendingSocket) -> + delegate + .accept(connectionSetupPayload, sendingSocket) + .subscribeOn(Schedulers.parallel())); + } + + if (withResumability) { + registry.forConnection( + (type, duplexConnection) -> + type == DuplexConnectionInterceptor.Type.SOURCE + ? new DisconnectingDuplexConnection( + "Server", + duplexConnection, + Duration.ofMillis( + ThreadLocalRandom.current().nextInt(100, 1000))) + : duplexConnection); + } + }); + + if (withResumability) { + rSocketServer.resume( + new Resume() + .storeFactory( + token -> new InMemoryResumableFramesStore("server", token, Integer.MAX_VALUE))); + } + + if (withRandomFragmentation) { + rSocketServer.fragment(ThreadLocalRandom.current().nextInt(256, 512)); + } + + server = + rSocketServer.bind(serverTransportSupplier.apply(address, allocatorToSupply2)).block(); + + final RSocketConnector rSocketConnector = + RSocketConnector.create() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .keepAlive(Duration.ofMillis(10), Duration.ofHours(1)) + .interceptors( + registry -> { + if (runClientWithAsyncInterceptors && !withResumability) { + logger.info( + "Perform Integration Test with Async Interceptors Enabled For Client"); + registry + .forConnection( + (type, duplexConnection) -> + new AsyncDuplexConnection(duplexConnection, "client")) + .forSocketAcceptor( + delegate -> + (connectionSetupPayload, sendingSocket) -> + delegate + .accept(connectionSetupPayload, sendingSocket) + .subscribeOn(Schedulers.parallel())); + } + + if (withResumability) { + registry.forConnection( + (type, duplexConnection) -> + type == DuplexConnectionInterceptor.Type.SOURCE + ? new DisconnectingDuplexConnection( + "Client", + duplexConnection, + Duration.ofMillis( + ThreadLocalRandom.current().nextInt(10, 1500))) + : duplexConnection); + } + }); + + if (withResumability) { + rSocketConnector.resume( + new Resume() + .storeFactory( + token -> new InMemoryResumableFramesStore("client", token, Integer.MAX_VALUE))); + } + + if (withRandomFragmentation) { + rSocketConnector.fragment(ThreadLocalRandom.current().nextInt(256, 512)); + } + + client = + rSocketConnector + .connect(clientTransportSupplier.apply(address, server, allocatorToSupply1)) + .doOnError(Throwable::printStackTrace) + .block(); + } + + @Override + public void dispose() { + logger.info("terminating transport pair"); + client.dispose(); + } + + RSocket getClient() { + return client; + } + + public String expectedPayloadData() { + return data; + } + + public String expectedPayloadMetadata() { + return metadata; + } + + public void awaitClosed(Duration timeout) { + logger.info("awaiting termination of transport pair"); + logger.info( + "wrappers combination: client{async=" + + runClientWithAsyncInterceptors + + "; resume=" + + withResumability + + "} server{async=" + + runServerWithAsyncInterceptors + + "; resume=" + + withResumability + + "}"); + client + .onClose() + .doOnSubscribe(s -> logger.info("Client termination stage=onSubscribe(" + s + ")")) + .doOnEach(s -> logger.info("Client termination stage=" + s)) + .onErrorResume(t -> Mono.empty()) + .doOnTerminate(() -> logger.info("Client terminated. Terminating Server")) + .then(Mono.fromRunnable(server::dispose)) + .then( + server + .onClose() + .doOnSubscribe( + s -> logger.info("Server termination stage=onSubscribe(" + s + ")")) + .doOnEach(s -> logger.info("Server termination stage=" + s))) + .onErrorResume(t -> Mono.empty()) + .block(timeout); + + logger.info("TransportPair has been terminated"); + } + + private static class AsyncDuplexConnection implements DuplexConnection { + + private final DuplexConnection duplexConnection; + private String tag; + private final ByteBufReleaserOperator bufReleaserOperator; + + public AsyncDuplexConnection(DuplexConnection duplexConnection, String tag) { + this.duplexConnection = duplexConnection; + this.tag = tag; + this.bufReleaserOperator = new ByteBufReleaserOperator(); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + duplexConnection.sendFrame(streamId, frame); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + duplexConnection.sendErrorAndClose(e); + } + + @Override + public Flux receive() { + return duplexConnection + .receive() + .doOnTerminate(() -> logger.info("[" + this + "] Receive is done before PO")) + .subscribeOn(Schedulers.boundedElastic()) + .doOnNext(ByteBuf::retain) + .publishOn(Schedulers.boundedElastic(), Integer.MAX_VALUE) + .doOnTerminate(() -> logger.info("[" + this + "] Receive is done after PO")) + .doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::safeRelease) + .transform( + Operators.lift( + (__, actual) -> { + bufReleaserOperator.actual = actual; + return bufReleaserOperator; + })); + } + + @Override + public ByteBufAllocator alloc() { + return duplexConnection.alloc(); + } + + @Override + public SocketAddress remoteAddress() { + return duplexConnection.remoteAddress(); + } + + @Override + public Mono onClose() { + return Mono.whenDelayError( + duplexConnection + .onClose() + .doOnTerminate(() -> logger.info("[" + this + "] Source Connection is done")), + bufReleaserOperator + .onClose() + .doOnTerminate(() -> logger.info("[" + this + "] BufferReleaser is done"))); + } + + @Override + public void dispose() { + duplexConnection.dispose(); + } + + @Override + public String toString() { + return "AsyncDuplexConnection{" + + "duplexConnection=" + + duplexConnection + + ", tag='" + + tag + + '\'' + + ", bufReleaserOperator=" + + bufReleaserOperator + + '}'; + } + } + + private static class DisconnectingDuplexConnection implements DuplexConnection { + + private final String tag; + final DuplexConnection source; + final Duration delay; + final Disposable.Swap disposables = Disposables.swap(); + + DisconnectingDuplexConnection(String tag, DuplexConnection source, Duration delay) { + this.tag = tag; + this.source = source; + this.delay = delay; + } + + @Override + public void dispose() { + disposables.dispose(); + source.dispose(); + } + + @Override + public Mono onClose() { + return source + .onClose() + .doOnTerminate(() -> logger.info("[" + this + "] Source Connection is done")); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + source.sendFrame(streamId, frame); + } + + @Override + public void sendErrorAndClose(RSocketErrorException errorException) { + source.sendErrorAndClose(errorException); + } + + boolean receivedFirst; + + @Override + public Flux receive() { + return source + .receive() + .doOnSubscribe( + __ -> logger.warn("Tag {}. Subscribing Connection[{}]", tag, source.hashCode())) + .doOnNext( + bb -> { + if (!receivedFirst) { + receivedFirst = true; + disposables.replace( + Mono.delay(delay) + .takeUntilOther(source.onClose()) + .subscribe( + __ -> { + logger.warn( + "Tag {}. Disposing Connection[{}]", tag, source.hashCode()); + source.dispose(); + })); + } + }); + } + + @Override + public ByteBufAllocator alloc() { + return source.alloc(); + } + + @Override + public SocketAddress remoteAddress() { + return source.remoteAddress(); + } + + @Override + public String toString() { + return "DisconnectingDuplexConnection{" + + "tag='" + + tag + + '\'' + + ", source=" + + source + + ", disposables=" + + disposables + + '}'; + } + } + + private static class ByteBufReleaserOperator + implements CoreSubscriber, Subscription, Fuseable.QueueSubscription { + + CoreSubscriber actual; + final Sinks.Empty closeableMonoSink; + + Subscription s; + + public ByteBufReleaserOperator() { + this.closeableMonoSink = Sinks.unsafe().empty(); + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + actual.onSubscribe(this); + } + } + + @Override + public void onNext(ByteBuf buf) { + try { + actual.onNext(buf); + } finally { + buf.release(); + } + } + + Mono onClose() { + return closeableMonoSink.asMono(); + } + + @Override + public void onError(Throwable t) { + actual.onError(t); + closeableMonoSink.tryEmitError(t); + } + + @Override + public void onComplete() { + actual.onComplete(); + closeableMonoSink.tryEmitEmpty(); + } + + @Override + public void request(long n) { + s.request(n); + } + + @Override + public void cancel() { + s.cancel(); + closeableMonoSink.tryEmitEmpty(); + } + + @Override + public int requestFusion(int requestedMode) { + return Fuseable.NONE; + } + + @Override + public ByteBuf poll() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public int size() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public boolean isEmpty() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public String toString() { + return "ByteBufReleaserOperator{" + + "isActualPresent=" + + (actual != null) + + ", " + + "isSubscriptionPresent=" + + (s != null) + + '}'; + } + } + } + + class PayloadPredicate implements Predicate { + final int expectedCnt; + int cnt; + + public PayloadPredicate(int expectedCnt) { + this.expectedCnt = expectedCnt; + } + + @Override + public boolean test(Payload p) { + boolean shouldConsume = cnt++ < expectedCnt; + if (!shouldConsume) { + logger.info( + "Metadata: \n\r{}\n\rData:{}", + p.hasMetadata() + ? new ByteBufRepresentation().fallbackToStringOf(p.sliceMetadata()) + : "Empty", + new ByteBufRepresentation().fallbackToStringOf(p.sliceData())); + } + return shouldConsume; + } + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/TriFunction.java b/rsocket-test/src/main/java/io/rsocket/test/TriFunction.java new file mode 100644 index 000000000..87a1d4dbf --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/TriFunction.java @@ -0,0 +1,6 @@ +package io.rsocket.test; + +@FunctionalInterface +public interface TriFunction { + R apply(T1 t1, T2 t2, T3 t3); +} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/TimedOutException.java b/rsocket-test/src/main/java/io/rsocket/test/package-info.java similarity index 68% rename from rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/TimedOutException.java rename to rsocket-test/src/main/java/io/rsocket/test/package-info.java index 79566f87d..600ac2b82 100644 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/TimedOutException.java +++ b/rsocket-test/src/main/java/io/rsocket/test/package-info.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.rsocket.aeron.internal; -public class TimedOutException extends RuntimeException { +/** Utilities for testing RSocket components. */ +@NonNullApi +package io.rsocket.test; - private static final long serialVersionUID = 6252022225519863073L; -} +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-test/src/main/resources/META-INF/services/org.assertj.core.presentation.Representation b/rsocket-test/src/main/resources/META-INF/services/org.assertj.core.presentation.Representation new file mode 100644 index 000000000..0c33b5ff7 --- /dev/null +++ b/rsocket-test/src/main/resources/META-INF/services/org.assertj.core.presentation.Representation @@ -0,0 +1,16 @@ +# +# Copyright 2015-2018 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +io.rsocket.test.ByteBufRepresentation \ No newline at end of file diff --git a/rsocket-test/src/main/resources/words.shakespeare.txt.gz b/rsocket-test/src/main/resources/words.shakespeare.txt.gz new file mode 100644 index 000000000..422a4b331 Binary files /dev/null and b/rsocket-test/src/main/resources/words.shakespeare.txt.gz differ diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/AeronDuplexConnection.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/AeronDuplexConnection.java deleted file mode 100644 index 4fe2a5952..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/AeronDuplexConnection.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.aeron; - -import io.netty.buffer.Unpooled; -import io.rsocket.DuplexConnection; -import io.rsocket.Frame; -import io.rsocket.aeron.internal.reactivestreams.AeronChannel; -import org.agrona.concurrent.UnsafeBuffer; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; - -/** - * Implementation of {@link DuplexConnection} over Aeron using an {@link - * io.rsocket.aeron.internal.reactivestreams.AeronChannel} - */ -public class AeronDuplexConnection implements DuplexConnection { - private final String name; - private final AeronChannel channel; - private final MonoProcessor emptySubject; - - public AeronDuplexConnection(String name, AeronChannel channel) { - this.name = name; - this.channel = channel; - this.emptySubject = MonoProcessor.create(); - } - - @Override - public Mono send(Publisher frame) { - Flux buffers = - Flux.from(frame).map(f -> new UnsafeBuffer(f.content().nioBuffer())); - - return channel.send(buffers); - } - - @Override - public Flux receive() { - return channel - .receive() - .map(b -> Frame.from(Unpooled.wrappedBuffer(b.byteBuffer()))) - .doOnError(Throwable::printStackTrace); - } - - @Override - public double availability() { - return channel.isActive() ? 1.0 : 0.0; - } - - @Override - public Mono close() { - return Mono.defer( - () -> { - try { - channel.close(); - emptySubject.onComplete(); - } catch (Exception e) { - emptySubject.onError(e); - } - return emptySubject; - }); - } - - @Override - public Mono onClose() { - return emptySubject; - } - - @Override - public String toString() { - return "AeronDuplexConnection{" - + "name='" - + name - + '\'' - + ", channel=" - + channel - + ", emptySubject=" - + emptySubject - + '}'; - } -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/client/AeronClientTransport.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/client/AeronClientTransport.java deleted file mode 100644 index 6b9ec1e73..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/client/AeronClientTransport.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.aeron.client; - -import io.rsocket.DuplexConnection; -import io.rsocket.aeron.AeronDuplexConnection; -import io.rsocket.aeron.internal.reactivestreams.AeronChannel; -import io.rsocket.aeron.internal.reactivestreams.AeronClientChannelConnector; -import io.rsocket.transport.ClientTransport; -import java.util.Objects; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Mono; - -/** {@link ClientTransport} implementation that uses Aeron as a transport */ -public class AeronClientTransport implements ClientTransport { - private final AeronClientChannelConnector connector; - private final AeronClientChannelConnector.AeronClientConfig config; - - public AeronClientTransport( - AeronClientChannelConnector connector, AeronClientChannelConnector.AeronClientConfig config) { - Objects.requireNonNull(config); - Objects.requireNonNull(connector); - this.connector = connector; - this.config = config; - } - - @Override - public Mono connect() { - Publisher channelPublisher = connector.apply(config); - - return Mono.from(channelPublisher) - .map(aeronChannel -> new AeronDuplexConnection("client", aeronChannel)); - } -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/AeronWrapper.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/AeronWrapper.java deleted file mode 100644 index f4ec61a89..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/AeronWrapper.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.aeron.internal; - -import io.aeron.Aeron; -import io.aeron.Image; -import io.aeron.Publication; -import io.aeron.Subscription; -import java.util.function.Function; - -/** */ -public interface AeronWrapper { - Aeron getAeron(); - - void availableImageHandler(Function handler); - - void unavailableImageHandlers(Function handler); - - default Subscription addSubscription(String channel, int streamId) { - return getAeron().addSubscription(channel, streamId); - } - - default Publication addPublication(String channel, int streamId) { - return getAeron().addPublication(channel, streamId); - } - - default void close() { - getAeron().close(); - } -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/Constants.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/Constants.java deleted file mode 100644 index c030bc9a1..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/Constants.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.aeron.internal; - -import java.util.concurrent.TimeUnit; -import org.agrona.concurrent.BackoffIdleStrategy; -import org.agrona.concurrent.IdleStrategy; -import org.agrona.concurrent.NoOpIdleStrategy; -import org.agrona.concurrent.SleepingIdleStrategy; - -public final class Constants { - - public static final int SERVER_STREAM_ID = 0; - public static final int CLIENT_STREAM_ID = 1; - public static final int SERVER_MANAGEMENT_STREAM_ID = 10; - public static final int CLIENT_MANAGEMENT_STREAM_ID = 11; - public static final IdleStrategy EVENT_LOOP_IDLE_STRATEGY; - public static final int AERON_MTU_SIZE = Integer.getInteger("aeron.mtu.length", 4096); - - static { - String idlStrategy = System.getProperty("idleStrategy"); - - if (NoOpIdleStrategy.class.getName().equalsIgnoreCase(idlStrategy)) { - EVENT_LOOP_IDLE_STRATEGY = new NoOpIdleStrategy(); - } else if (SleepingIdleStrategy.class.getName().equalsIgnoreCase(idlStrategy)) { - EVENT_LOOP_IDLE_STRATEGY = new SleepingIdleStrategy(TimeUnit.MILLISECONDS.toNanos(10)); - } else { - EVENT_LOOP_IDLE_STRATEGY = new BackoffIdleStrategy(1, 10, 1_000, 100_000); - } - } - - private Constants() {} -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/DefaultAeronWrapper.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/DefaultAeronWrapper.java deleted file mode 100644 index a3e0d9af8..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/DefaultAeronWrapper.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.aeron.internal; - -import io.aeron.Aeron; -import io.aeron.Image; -import java.util.HashSet; -import java.util.Iterator; -import java.util.Set; -import java.util.concurrent.CopyOnWriteArraySet; -import java.util.function.Function; - -/** */ -public class DefaultAeronWrapper implements AeronWrapper { - private Set> availableImageHandlers; - private Set> unavailableImageHandlers; - - private Aeron aeron; - - public DefaultAeronWrapper() { - this.availableImageHandlers = new CopyOnWriteArraySet<>(); - this.unavailableImageHandlers = new CopyOnWriteArraySet<>(); - - Aeron.Context ctx = new Aeron.Context(); - - ctx.availableImageHandler(this::availableImageHandler); - ctx.unavailableImageHandler(this::unavailableImageHandler); - - this.aeron = Aeron.connect(ctx); - } - - public Aeron getAeron() { - return aeron; - } - - public void availableImageHandler(Function handler) { - availableImageHandlers.add(handler); - } - - public void unavailableImageHandlers(Function handler) { - unavailableImageHandlers.add(handler); - } - - private void availableImageHandler(Image image) { - Iterator> iterator = availableImageHandlers.iterator(); - - Set> itemsToRemove = new HashSet<>(); - while (iterator.hasNext()) { - Function handler = iterator.next(); - if (handler.apply(image)) { - itemsToRemove.add(handler); - } - } - - availableImageHandlers.removeAll(itemsToRemove); - } - - private void unavailableImageHandler(Image image) { - unavailableImageHandlers.removeIf(handler -> handler.apply(image)); - } -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/EventLoop.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/EventLoop.java deleted file mode 100644 index 4bb4adec1..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/EventLoop.java +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.aeron.internal; - -import java.util.function.IntSupplier; - -/** Interface for an EventLoop used by Aeron */ -public interface EventLoop { - /** - * Executes an IntSupplier that returns a number greater than 0 if it wants the the event loop to - * keep processing items, and zero its okay for the eventloop to execute an idle strategy - * - * @param r signal for roughly how many items could be processed. - * @return whether items could be processed - */ - boolean execute(IntSupplier r); -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/SingleThreadedEventLoop.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/SingleThreadedEventLoop.java deleted file mode 100644 index ba8b8e24d..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/SingleThreadedEventLoop.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.aeron.internal; - -import java.util.concurrent.locks.LockSupport; -import java.util.function.IntSupplier; -import org.agrona.concurrent.IdleStrategy; -import org.agrona.concurrent.OneToOneConcurrentArrayQueue; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** */ -public class SingleThreadedEventLoop implements EventLoop { - private static final Logger logger = LoggerFactory.getLogger(SingleThreadedEventLoop.class); - private final String name; - private final Thread thread; - private final OneToOneConcurrentArrayQueue events = - new OneToOneConcurrentArrayQueue<>(32768); - - public SingleThreadedEventLoop(String name) { - this.name = name; - logger.info("Starting event loop named => {}", name); - - thread = new Thread(new SingleThreadedEventLoopRunnable()); - thread.setDaemon(true); - thread.setName("aeron-single-threaded-event-loop-" + name); - thread.start(); - } - - @Override - public boolean execute(IntSupplier r) { - boolean offer; - - if (thread == Thread.currentThread()) { - offer = events.offer(r); - } else { - synchronized (this) { - offer = events.offer(r); - } - LockSupport.unpark(thread); - } - - return offer; - } - - private int drain() { - int count = 0; - while (!events.isEmpty()) { - IntSupplier poll = events.poll(); - if (poll != null) { - count += poll.getAsInt(); - } - } - - return count; - } - - private class SingleThreadedEventLoopRunnable implements Runnable { - final IdleStrategy idleStrategy = Constants.EVENT_LOOP_IDLE_STRATEGY; - - @Override - public void run() { - while (true) { - try { - int count = drain(); - // if (count > 100) { - // System.out.println(name + " drained..." + count); - // } - idleStrategy.idle(count); - } catch (Throwable t) { - System.err.println("Something bad happened - an error made it to the event loop"); - t.printStackTrace(); - } - } - } - } - - @Override - public String toString() { - return "SingleThreadedEventLoop{" + "name='" + name + '\'' + '}'; - } -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/AeronChannel.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/AeronChannel.java deleted file mode 100644 index 4dff523bd..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/AeronChannel.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.aeron.internal.reactivestreams; - -import io.aeron.Publication; -import io.aeron.Subscription; -import io.rsocket.aeron.internal.EventLoop; -import java.io.IOException; -import java.util.Objects; -import org.agrona.DirectBuffer; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -/** */ -public class AeronChannel implements ReactiveStreamsRemote.Channel, AutoCloseable { - private final String name; - private final Publication destination; - private final Subscription source; - private final AeronOutPublisher outPublisher; - private final EventLoop eventLoop; - - /** - * Creates on end of a bi-directional channel - * - * @param name name of the channel - * @param destination {@code Publication} to send data to - * @param source Aeron {@code Subscription} to listen to data on - * @param eventLoop {@link EventLoop} used to poll data on - * @param sessionId sessionId between the {@code Publication} and the remote {@code Subscription} - */ - public AeronChannel( - String name, - Publication destination, - Subscription source, - EventLoop eventLoop, - int sessionId) { - this.destination = destination; - this.source = source; - this.name = name; - this.eventLoop = eventLoop; - this.outPublisher = new AeronOutPublisher(name, sessionId, source, eventLoop); - } - - /** - * Subscribes to a stream of DirectBuffers and sends the to an Aeron Publisher - * - * @param in the publisher of buffers. - * @return Mono the completes when all publishers have been sent. - */ - public Mono send(Flux in) { - AeronInSubscriber inSubscriber = new AeronInSubscriber(name, destination); - Objects.requireNonNull(in, "in must not be null"); - return Mono.create( - sink -> in.doOnComplete(sink::success).doOnError(sink::error).subscribe(inSubscriber)); - } - - /** - * Returns ReactiveStreamsRemote.Out of DirectBuffer that can only be subscribed to once per - * channel - * - * @return ReactiveStreamsRemote.Out of DirectBuffer - */ - public Flux receive() { - return outPublisher; - } - - @Override - public void close() throws IOException { - try { - destination.close(); - source.close(); - } catch (Throwable t) { - throw new IOException(t); - } - } - - @Override - public String toString() { - return "AeronChannel{" + "name='" + name + '\'' + '}'; - } - - @Override - public boolean isActive() { - return !destination.isClosed() && !source.isClosed(); - } -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/AeronChannelServer.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/AeronChannelServer.java deleted file mode 100644 index 67ec5700b..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/AeronChannelServer.java +++ /dev/null @@ -1,285 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.aeron.internal.reactivestreams; - -import io.aeron.FragmentAssembler; -import io.aeron.Publication; -import io.aeron.Subscription; -import io.aeron.logbuffer.FragmentHandler; -import io.aeron.logbuffer.Header; -import io.rsocket.Closeable; -import io.rsocket.aeron.internal.AeronWrapper; -import io.rsocket.aeron.internal.Constants; -import io.rsocket.aeron.internal.EventLoop; -import io.rsocket.aeron.internal.NotConnectedException; -import io.rsocket.aeron.internal.reactivestreams.messages.AckConnectEncoder; -import io.rsocket.aeron.internal.reactivestreams.messages.ConnectDecoder; -import io.rsocket.aeron.internal.reactivestreams.messages.MessageHeaderDecoder; -import io.rsocket.aeron.internal.reactivestreams.messages.MessageHeaderEncoder; -import java.net.SocketAddress; -import java.nio.ByteBuffer; -import java.time.Duration; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import org.agrona.DirectBuffer; -import org.agrona.concurrent.UnsafeBuffer; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; - -/** - * Implementation of {@link - * io.rsocket.aeron.internal.reactivestreams.ReactiveStreamsRemote.ChannelServer} that manages - * {@link AeronChannel}s. - */ -public class AeronChannelServer - extends ReactiveStreamsRemote.ChannelServer { - private static final Logger logger = LoggerFactory.getLogger(AeronChannelServer.class); - private final AeronWrapper aeronWrapper; - private final AeronSocketAddress managementSubscriptionSocket; - private final AtomicBoolean started = new AtomicBoolean(false); - private final ConcurrentHashMap serverSubscriptions; - private volatile boolean running = true; - private final EventLoop eventLoop; - private Subscription managementSubscription; - private AeronChannelStartedServer startServer; - - private AeronChannelServer( - AeronChannelConsumer channelConsumer, - AeronWrapper aeronWrapper, - AeronSocketAddress managementSubscriptionSocket, - EventLoop eventLoop) { - super(channelConsumer); - this.aeronWrapper = aeronWrapper; - this.managementSubscriptionSocket = managementSubscriptionSocket; - this.eventLoop = eventLoop; - this.serverSubscriptions = new ConcurrentHashMap<>(); - } - - public static AeronChannelServer create( - AeronChannelConsumer channelConsumer, - AeronWrapper aeronWrapper, - AeronSocketAddress managementSubscriptionSocket, - EventLoop eventLoop) { - return new AeronChannelServer( - channelConsumer, aeronWrapper, managementSubscriptionSocket, eventLoop); - } - - @Override - public AeronChannelStartedServer start() { - if (!started.compareAndSet(false, true)) { - throw new IllegalStateException("server already started"); - } - - logger.debug( - "management server starting on {}, stream id {}", - managementSubscriptionSocket.getChannel(), - Constants.SERVER_MANAGEMENT_STREAM_ID); - - this.managementSubscription = - aeronWrapper.addSubscription( - managementSubscriptionSocket.getChannel(), Constants.SERVER_MANAGEMENT_STREAM_ID); - - this.startServer = new AeronChannelStartedServer(); - - poll(); - - return startServer; - } - - private final FragmentAssembler fragmentAssembler = - new FragmentAssembler( - new FragmentHandler() { - private final MessageHeaderDecoder messageHeaderDecoder = new MessageHeaderDecoder(); - private final ConnectDecoder connectDecoder = new ConnectDecoder(); - private final MessageHeaderEncoder messageHeaderEncoder = new MessageHeaderEncoder(); - private final AckConnectEncoder ackConnectEncoder = new AckConnectEncoder(); - - @Override - public void onFragment(DirectBuffer buffer, int offset, int length, Header header) { - messageHeaderDecoder.wrap(buffer, offset); - - // Do not change the order or remove fields - final int actingBlockLength = messageHeaderDecoder.blockLength(); - final int templateId = messageHeaderDecoder.templateId(); - final int schemaId = messageHeaderDecoder.schemaId(); - final int actingVersion = messageHeaderDecoder.version(); - - if (templateId == ConnectDecoder.TEMPLATE_ID) { - offset += messageHeaderDecoder.encodedLength(); - connectDecoder.wrap(buffer, offset, actingBlockLength, actingVersion); - - // Do not change the order or remove fields - long channelId = connectDecoder.channelId(); - String receivingChannel = connectDecoder.receivingChannel(); - int receivingStreamId = connectDecoder.receivingStreamId(); - String sendingChannel = connectDecoder.sendingChannel(); - int sendingStreamId = connectDecoder.sendingStreamId(); - int clientSessionId = connectDecoder.clientSessionId(); - String clientManagementChannel = connectDecoder.clientManagementChannel(); - - logger.debug( - "server creating a AeronChannel with channel id {} receiving on receivingChannel {}, receivingStreamId {}, sendingChannel {}, sendingStreamId {}", - channelId, - receivingChannel, - receivingStreamId, - sendingChannel, - sendingStreamId); - - // Server sends to receiving Channel - Publication destination = - aeronWrapper.addPublication(receivingChannel, receivingStreamId); - int sessionId = destination.sessionId(); - logger.debug( - "server created publication to channel {}, stream id {}, and session id {}", - receivingChannel, - receivingStreamId, - sessionId); - - // Server listens to sending channel - Subscription source = - serverSubscriptions.computeIfAbsent( - sendingChannel, - s -> aeronWrapper.addSubscription(sendingChannel, sendingStreamId)); - logger.debug( - "server created subscription to channel {}, stream id {}", - sendingChannel, - sendingStreamId); - - AeronChannel aeronChannel = - new AeronChannel("server", destination, source, eventLoop, clientSessionId); - logger.debug( - "server create AeronChannel with destination channel {}, source channel {}, and clientSessionId {}"); - - channelConsumer.accept(aeronChannel); - - Publication managementPublication = - aeronWrapper.addPublication( - clientManagementChannel, Constants.CLIENT_MANAGEMENT_STREAM_ID); - logger.debug( - "server created management publication to channel {}", clientManagementChannel); - - final ByteBuffer byteBuffer = ByteBuffer.allocateDirect(4096); - final UnsafeBuffer directBuffer = new UnsafeBuffer(byteBuffer); - int bufferOffset = 0; - - messageHeaderEncoder - .wrap(directBuffer, bufferOffset) - .blockLength(AckConnectEncoder.BLOCK_LENGTH) - .templateId(AckConnectEncoder.TEMPLATE_ID) - .schemaId(AckConnectEncoder.SCHEMA_ID) - .version(AckConnectEncoder.SCHEMA_VERSION); - - bufferOffset += messageHeaderEncoder.encodedLength(); - - ackConnectEncoder - .wrap(directBuffer, bufferOffset) - .channelId(channelId) - .serverSessionId(destination.sessionId()); - - logger.debug( - "server sending AckConnect message to channel {}", clientManagementChannel); - - long offer; - do { - offer = managementPublication.offer(directBuffer); - if (offer == Publication.CLOSED) { - throw new NotConnectedException(); - } - } while (offer < 0); - } - } - }); - - private int poll() { - int poll; - try { - poll = managementSubscription.poll(fragmentAssembler, 4096); - } finally { - if (running) { - boolean execute = eventLoop.execute(this::poll); - if (!execute) { - running = false; - throw new IllegalStateException("unable to keep polling, eventLoop rejection"); - } - } - } - - return poll; - } - - public interface AeronChannelConsumer - extends ReactiveStreamsRemote.ChannelConsumer {} - - public class AeronChannelStartedServer implements ReactiveStreamsRemote.StartedServer, Closeable { - private final MonoProcessor onClose = MonoProcessor.create(); - - private CountDownLatch latch = new CountDownLatch(1); - - public AeronWrapper getAeronWrapper() { - return aeronWrapper; - } - - public EventLoop getEventLoop() { - return eventLoop; - } - - @Override - public SocketAddress getServerAddress() { - return managementSubscriptionSocket; - } - - @Override - public int getServerPort() { - return managementSubscriptionSocket.getPort(); - } - - @Override - public void awaitShutdown(long duration, TimeUnit durationUnit) { - Duration d = Duration.ofMillis(durationUnit.toMillis(duration)); - close().block(d); - } - - @Override - public void awaitShutdown() { - close().block(); - } - - @Override - public void shutdown() { - close().subscribe(); - } - - @Override - public Mono close() { - return Mono.defer( - () -> { - running = false; - managementSubscription.close(); - return onClose; - }); - } - - @Override - public Mono onClose() { - return onClose; - } - } -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/AeronClientChannelConnector.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/AeronClientChannelConnector.java deleted file mode 100644 index 566afd7e6..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/AeronClientChannelConnector.java +++ /dev/null @@ -1,341 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.aeron.internal.reactivestreams; - -import io.aeron.FragmentAssembler; -import io.aeron.Publication; -import io.aeron.Subscription; -import io.aeron.logbuffer.FragmentHandler; -import io.aeron.logbuffer.Header; -import io.rsocket.aeron.internal.AeronWrapper; -import io.rsocket.aeron.internal.Constants; -import io.rsocket.aeron.internal.EventLoop; -import io.rsocket.aeron.internal.NotConnectedException; -import io.rsocket.aeron.internal.reactivestreams.messages.AckConnectDecoder; -import io.rsocket.aeron.internal.reactivestreams.messages.ConnectEncoder; -import io.rsocket.aeron.internal.reactivestreams.messages.MessageHeaderDecoder; -import io.rsocket.aeron.internal.reactivestreams.messages.MessageHeaderEncoder; -import java.nio.ByteBuffer; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicLong; -import java.util.function.IntConsumer; -import org.agrona.DirectBuffer; -import org.agrona.concurrent.UnsafeBuffer; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.publisher.Mono; -import reactor.core.publisher.Operators; - -/** Brokers a connection to a remote Aeron server. */ -public class AeronClientChannelConnector - implements ReactiveStreamsRemote.ClientChannelConnector< - AeronClientChannelConnector.AeronClientConfig, AeronChannel>, - AutoCloseable { - - private static final Logger logger = LoggerFactory.getLogger(AeronClientChannelConnector.class); - - private static final AtomicLong CHANNEL_ID_COUNTER = new AtomicLong(); - - private final AeronWrapper aeronWrapper; - - // Subscriptions clients listen to responses on - private final ConcurrentHashMap clientSubscriptions; - private final ConcurrentHashMap serverSessionIdConsumerMap; - - private final Subscription managementSubscription; - - private final EventLoop eventLoop; - - private volatile boolean running = true; - - private AeronClientChannelConnector( - AeronWrapper aeronWrapper, - AeronSocketAddress managementSubscriptionSocket, - EventLoop eventLoop) { - this.aeronWrapper = aeronWrapper; - - logger.debug( - "client creating a management subscription on channel {}, stream id {}", - managementSubscriptionSocket.getChannel(), - Constants.CLIENT_MANAGEMENT_STREAM_ID); - - this.managementSubscription = - aeronWrapper.addSubscription( - managementSubscriptionSocket.getChannel(), Constants.CLIENT_MANAGEMENT_STREAM_ID); - this.eventLoop = eventLoop; - this.clientSubscriptions = new ConcurrentHashMap<>(); - this.serverSessionIdConsumerMap = new ConcurrentHashMap<>(); - - poll(); - } - - public static AeronClientChannelConnector create( - AeronWrapper wrapper, AeronSocketAddress managementSubscriptionSocket, EventLoop eventLoop) { - return new AeronClientChannelConnector(wrapper, managementSubscriptionSocket, eventLoop); - } - - private final FragmentAssembler fragmentAssembler = - new FragmentAssembler( - new FragmentHandler() { - private final MessageHeaderDecoder messageHeaderDecoder = new MessageHeaderDecoder(); - private final AckConnectDecoder ackConnectDecoder = new AckConnectDecoder(); - - @Override - public void onFragment(DirectBuffer buffer, int offset, int length, Header header) { - messageHeaderDecoder.wrap(buffer, offset); - - // Do not change the order or remove fields - final int actingBlockLength = messageHeaderDecoder.blockLength(); - final int templateId = messageHeaderDecoder.templateId(); - final int schemaId = messageHeaderDecoder.schemaId(); - final int actingVersion = messageHeaderDecoder.version(); - - if (templateId == AckConnectDecoder.TEMPLATE_ID) { - logger.debug("client received an ack message on session id {}", header.sessionId()); - offset += messageHeaderDecoder.encodedLength(); - ackConnectDecoder.wrap(buffer, offset, actingBlockLength, actingVersion); - long channelId = ackConnectDecoder.channelId(); - int serverSessionId = ackConnectDecoder.serverSessionId(); - - logger.debug( - "client received ack message for channel id {} and server session id {}", - channelId, - serverSessionId); - - IntConsumer intConsumer = serverSessionIdConsumerMap.remove(channelId); - - if (intConsumer != null) { - intConsumer.accept(serverSessionId); - } else { - throw new IllegalStateException("no channel found for channel id " + channelId); - } - } else { - throw new IllegalStateException("received unknown template id " + templateId); - } - } - }); - - private int poll() { - int poll; - try { - poll = managementSubscription.poll(fragmentAssembler, 4096); - } finally { - if (running) { - boolean execute = eventLoop.execute(this::poll); - if (!execute) { - running = false; - throw new IllegalStateException("unable to keep polling, eventLoop rejection"); - } - } - } - - return poll; - } - - @Override - public Mono apply(AeronClientConfig aeronClientConfig) { - return Mono.from( - subscriber -> { - subscriber.onSubscribe(Operators.emptySubscription()); - final long channelId = CHANNEL_ID_COUNTER.get(); - try { - - logger.debug("Creating new client channel with id {}", channelId); - final Publication destination = - aeronWrapper.addPublication( - aeronClientConfig.sendSocketAddress.getChannel(), - aeronClientConfig.sendStreamId); - int destinationStreamId = destination.streamId(); - - logger.debug( - "Client created publication to {}, on stream id {}, and session id {}", - aeronClientConfig.sendSocketAddress, - aeronClientConfig.sendStreamId, - destination.sessionId()); - - final Subscription source = - clientSubscriptions.computeIfAbsent( - aeronClientConfig.receiveSocketAddress, - address -> { - Subscription subscription = - aeronWrapper.addSubscription( - aeronClientConfig.receiveSocketAddress.getChannel(), - aeronClientConfig.receiveStreamId); - logger.debug( - "Client created subscription to {}, on stream id {}", - aeronClientConfig.receiveSocketAddress, - aeronClientConfig.receiveStreamId); - return subscription; - }); - - IntConsumer sessionIdConsumer = - sessionId -> { - try { - AeronChannel aeronChannel = - new AeronChannel( - "client", destination, source, aeronClientConfig.eventLoop, sessionId); - logger.debug( - "created client AeronChannel for destination {}, source {}, destination stream id {}, source stream id {}, client session id, and server session id {}", - aeronClientConfig.sendSocketAddress, - aeronClientConfig.receiveSocketAddress, - destination.streamId(), - source.streamId(), - destination.sessionId(), - sessionId); - subscriber.onNext(aeronChannel); - subscriber.onComplete(); - } catch (Throwable t) { - subscriber.onError(t); - } - }; - - serverSessionIdConsumerMap.putIfAbsent(channelId, sessionIdConsumer); - - aeronWrapper.unavailableImageHandlers( - image -> { - if (destinationStreamId == image.sessionId()) { - clientSubscriptions.remove(aeronClientConfig.receiveSocketAddress); - return true; - } else { - return false; - } - }); - - Publication managementPublication = - aeronWrapper.addPublication( - aeronClientConfig.sendSocketAddress.getChannel(), - Constants.SERVER_MANAGEMENT_STREAM_ID); - logger.debug( - "Client created management publication to channel {}, stream id {}", - managementPublication.channel(), - managementPublication.streamId()); - - DirectBuffer buffer = - encodeConnectMessage(channelId, aeronClientConfig, destination.sessionId()); - long offer; - do { - offer = managementPublication.offer(buffer); - if (offer == Publication.CLOSED) { - subscriber.onError(new NotConnectedException()); - } - } while (offer < 0); - logger.debug("Client sent create message to {}", managementPublication.channel()); - - } catch (Throwable t) { - logger.error("Error creating a channel to {}", aeronClientConfig); - clientSubscriptions.remove(aeronClientConfig.receiveSocketAddress); - subscriber.onError(t); - } - }); - } - - public DirectBuffer encodeConnectMessage( - long channelId, AeronClientConfig config, int clientSessionId) { - final ByteBuffer byteBuffer = ByteBuffer.allocateDirect(4096); - final UnsafeBuffer directBuffer = new UnsafeBuffer(byteBuffer); - int bufferOffset = 0; - - MessageHeaderEncoder messageHeaderEncoder = new MessageHeaderEncoder(); - - // Do not channel the order - messageHeaderEncoder - .wrap(directBuffer, bufferOffset) - .blockLength(ConnectEncoder.BLOCK_LENGTH) - .templateId(ConnectEncoder.TEMPLATE_ID) - .schemaId(ConnectEncoder.SCHEMA_ID) - .version(ConnectEncoder.SCHEMA_VERSION); - - bufferOffset += messageHeaderEncoder.encodedLength(); - - ConnectEncoder connectEncoder = new ConnectEncoder(); - - // Do not change the order - connectEncoder - .wrap(directBuffer, bufferOffset) - .channelId(channelId) - .receivingChannel(config.receiveSocketAddress.getChannel()) - .receivingStreamId(config.receiveStreamId) - .sendingChannel(config.sendSocketAddress.getChannel()) - .sendingStreamId(config.sendStreamId) - .clientSessionId(clientSessionId) - .clientManagementChannel(managementSubscription.channel()); - - return directBuffer; - } - - public static class AeronClientConfig implements ReactiveStreamsRemote.ClientChannelConfig { - private final AeronSocketAddress receiveSocketAddress; - private final AeronSocketAddress sendSocketAddress; - private final int receiveStreamId; - private final int sendStreamId; - private final EventLoop eventLoop; - - private AeronClientConfig( - AeronSocketAddress receiveSocketAddress, - AeronSocketAddress sendSocketAddress, - int receiveStreamId, - int sendStreamId, - EventLoop eventLoop) { - this.receiveSocketAddress = receiveSocketAddress; - this.sendSocketAddress = sendSocketAddress; - this.receiveStreamId = receiveStreamId; - this.sendStreamId = sendStreamId; - this.eventLoop = eventLoop; - } - - /** - * Creates client a new {@code AeronClientConfig} for a {@link AeronChannel} - * - * @param receiveSocketAddress the address the channels receives data on - * @param sendSocketAddress the address the channel sends data too - * @param receiveStreamId receiving stream id - * @param sendStreamId the sending stream id - * @param eventLoop event loop for this client - * @return new {@code AeronClientConfig} - */ - public static AeronClientConfig create( - AeronSocketAddress receiveSocketAddress, - AeronSocketAddress sendSocketAddress, - int receiveStreamId, - int sendStreamId, - EventLoop eventLoop) { - return new AeronClientConfig( - receiveSocketAddress, sendSocketAddress, receiveStreamId, sendStreamId, eventLoop); - } - - @Override - public String toString() { - return "AeronClientConfig{" - + "receiveSocketAddress=" - + receiveSocketAddress - + ", sendSocketAddress=" - + sendSocketAddress - + ", receiveStreamId=" - + receiveStreamId - + ", sendStreamId=" - + sendStreamId - + ", eventLoop=" - + eventLoop - + '}'; - } - } - - @Override - public void close() { - running = false; - } -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/AeronInSubscriber.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/AeronInSubscriber.java deleted file mode 100644 index 8eef85f2f..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/AeronInSubscriber.java +++ /dev/null @@ -1,209 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.aeron.internal.reactivestreams; - -import io.aeron.Publication; -import io.aeron.logbuffer.BufferClaim; -import io.rsocket.aeron.internal.Constants; -import io.rsocket.aeron.internal.NotConnectedException; -import org.agrona.DirectBuffer; -import org.agrona.MutableDirectBuffer; -import org.agrona.concurrent.OneToOneConcurrentArrayQueue; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** */ -public class AeronInSubscriber implements Subscriber { - private static final Logger logger = LoggerFactory.getLogger(AeronInSubscriber.class); - private static final ThreadLocal bufferClaims = - ThreadLocal.withInitial(BufferClaim::new); - private static final int BUFFER_SIZE = 128; - private static final int REFILL = BUFFER_SIZE / 3; - - private static final OneToOneConcurrentArrayQueue> - queues = new OneToOneConcurrentArrayQueue<>(BUFFER_SIZE); - - private final OneToOneConcurrentArrayQueue buffers; - private final String name; - private final Publication destination; - - private Subscription subscription; - - private volatile boolean complete; - private volatile boolean erred = false; - - private volatile long requested; - - public AeronInSubscriber(String name, Publication destination) { - this.name = name; - this.destination = destination; - OneToOneConcurrentArrayQueue poll; - synchronized (queues) { - poll = queues.poll(); - } - buffers = poll != null ? poll : new OneToOneConcurrentArrayQueue<>(BUFFER_SIZE); - } - - @Override - public synchronized void onSubscribe(Subscription subscription) { - this.subscription = subscription; - requested = BUFFER_SIZE; - subscription.request(BUFFER_SIZE); - } - - @Override - public void onNext(DirectBuffer buffer) { - if (!erred) { - if (logger.isTraceEnabled()) { - logger.trace( - name - + " sending to destination => " - + destination.channel() - + " and aeron stream " - + destination.streamId() - + " and session id " - + destination.sessionId()); - } - boolean offer; - synchronized (buffers) { - offer = buffers.offer(buffer); - } - if (!offer) { - onError(new IllegalStateException("missing back-pressure")); - } - - tryEmit(); - } - } - - private boolean emitting = false; - private boolean missed = false; - - void tryEmit() { - synchronized (this) { - if (emitting) { - missed = true; - return; - } - } - - emit(); - } - - void emit() { - try { - for (; ; ) { - synchronized (this) { - missed = false; - } - while (!buffers.isEmpty()) { - DirectBuffer buffer = buffers.poll(); - tryClaimOrOffer(buffer); - requested--; - if (requested < REFILL) { - synchronized (buffers) { - if (!complete) { - long diff = BUFFER_SIZE - requested; - requested = BUFFER_SIZE; - subscription.request(diff); - } - } - } - } - - synchronized (this) { - if (!missed) { - emitting = false; - break; - } - } - } - } catch (Throwable t) { - onError(t); - } - - if (complete && buffers.isEmpty()) { - synchronized (queues) { - queues.offer(buffers); - } - } - } - - private void tryClaimOrOffer(DirectBuffer buffer) { - boolean successful = false; - - int capacity = buffer.capacity(); - if (capacity < Constants.AERON_MTU_SIZE) { - BufferClaim bufferClaim = bufferClaims.get(); - - while (!successful) { - long offer = destination.tryClaim(capacity, bufferClaim); - if (offer >= 0) { - try { - final MutableDirectBuffer b = bufferClaim.buffer(); - int offset = bufferClaim.offset(); - b.putBytes(offset, buffer, 0, capacity); - } finally { - bufferClaim.commit(); - successful = true; - } - } else { - if (offer == Publication.CLOSED) { - onError(new NotConnectedException(name)); - } - - successful = false; - } - } - - } else { - while (!successful) { - long offer = destination.offer(buffer); - - if (offer < 0) { - if (offer == Publication.CLOSED) { - onError(new NotConnectedException(name)); - } - } else { - successful = true; - } - } - } - } - - @Override - public synchronized void onError(Throwable t) { - if (!erred) { - erred = true; - subscription.cancel(); - } - - t.printStackTrace(); - } - - @Override - public synchronized void onComplete() { - complete = true; - tryEmit(); - } - - @Override - public String toString() { - return "AeronInSubscriber{" + "name='" + name + '\'' + '}'; - } -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/AeronOutPublisher.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/AeronOutPublisher.java deleted file mode 100644 index e7e881e54..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/AeronOutPublisher.java +++ /dev/null @@ -1,220 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.aeron.internal.reactivestreams; - -import io.aeron.ControlledFragmentAssembler; -import io.aeron.logbuffer.ControlledFragmentHandler; -import io.aeron.logbuffer.Header; -import io.rsocket.aeron.internal.EventLoop; -import io.rsocket.aeron.internal.NotConnectedException; -import java.nio.ByteBuffer; -import java.util.Objects; -import java.util.function.IntSupplier; -import org.agrona.DirectBuffer; -import org.agrona.concurrent.UnsafeBuffer; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import reactor.core.CoreSubscriber; -import reactor.core.publisher.Flux; - -/** */ -public class AeronOutPublisher extends Flux { - private static final Logger logger = LoggerFactory.getLogger(AeronOutPublisher.class); - private final io.aeron.Subscription source; - private final EventLoop eventLoop; - - private String name; - private volatile long requested; - private volatile long processed; - private Subscriber destination; - private AeronOutProcessorSubscription subscription; - private final int sessionId; - - /** - * Creates a publication for a unique session - * - * @param name publication's name - * @param sessionId sessionId between the source and the remote publication - * @param source Aeron {@code Subscription} publish data from - * @param eventLoop {@link EventLoop} to poll the source with - */ - public AeronOutPublisher( - String name, int sessionId, io.aeron.Subscription source, EventLoop eventLoop) { - this.name = name; - this.source = source; - this.eventLoop = eventLoop; - this.sessionId = sessionId; - } - - @Override - public void subscribe(CoreSubscriber destination) { - Objects.requireNonNull(destination); - synchronized (this) { - if (this.destination != null && subscription.canEmit()) { - throw new IllegalStateException( - "only allows one subscription => channel " - + source.channel() - + " and stream id => " - + source.streamId()); - } - this.destination = destination; - } - - this.subscription = new AeronOutProcessorSubscription(destination); - destination.onSubscribe(subscription); - } - - void onError(Throwable t) { - subscription.erred = true; - if (destination != null) { - destination.onError(t); - } - } - - void cancel() { - if (subscription != null) { - subscription.cancel(); - } - } - - @Override - public String toString() { - return "AeronOutPublisher{" + "name='" + name + '\'' + '}'; - } - - private class AeronOutProcessorSubscription implements Subscription { - private volatile boolean erred = false; - private volatile boolean cancelled = false; - private final Subscriber destination; - private final ControlledFragmentAssembler assembler; - - public AeronOutProcessorSubscription(Subscriber destination) { - this.destination = destination; - this.assembler = new ControlledFragmentAssembler(this::onFragment, 4096); - } - - boolean emitting = false; - boolean missed = false; - - @Override - public void request(long n) { - if (n < 0) { - onError(new IllegalStateException("n must be greater than zero")); - } - - synchronized (AeronOutPublisher.this) { - long r; - if (requested != Long.MAX_VALUE && n > 0) { - r = requested + n; - requested = r < 0 ? Long.MAX_VALUE : r; - } - } - - tryEmit(); - } - - // allocate this once - final IntSupplier supplier = this::emit; - - void tryEmit() { - synchronized (AeronOutPublisher.this) { - if (emitting) { - missed = true; - return; - } - emitting = true; - eventLoop.execute(supplier); - } - } - - ControlledFragmentHandler.Action onFragment( - DirectBuffer buffer, int offset, int length, Header header) { - if (sessionId != header.sessionId()) { - if (source.imageBySessionId(header.sessionId()) == null) { - return ControlledFragmentHandler.Action.CONTINUE; - } - - return ControlledFragmentHandler.Action.ABORT; - } - - try { - ByteBuffer bytes = ByteBuffer.allocate(length); - buffer.getBytes(offset, bytes, length); - - if (canEmit()) { - destination.onNext(new UnsafeBuffer(bytes)); - } - } catch (Throwable t) { - onError(t); - } - - return ControlledFragmentHandler.Action.COMMIT; - } - - int emit() { - int emitted = 0; - for (; ; ) { - synchronized (AeronOutPublisher.this) { - missed = false; - } - - try { - if (source.isClosed()) { - onError(new NotConnectedException(name)); - return 0; - } - - while (processed < requested) { - - int poll = source.controlledPoll(assembler, 4096); - - if (poll < 1) { - break; - } else { - emitted++; - processed++; - } - } - - synchronized (AeronOutPublisher.this) { - emitting = false; - break; - } - - } catch (Throwable t) { - onError(t); - } - } - - if (canEmit()) { - tryEmit(); - } - - return emitted; - } - - @Override - public void cancel() { - cancelled = true; - } - - private boolean canEmit() { - return !cancelled && !erred; - } - } -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/AeronSocketAddress.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/AeronSocketAddress.java deleted file mode 100644 index 2edfa2c77..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/AeronSocketAddress.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.aeron.internal.reactivestreams; - -import java.net.SocketAddress; - -/** SocketAddress that represents an Aeron Channel */ -public class AeronSocketAddress extends SocketAddress { - private static final String FORMAT = "%s?endpoint=%s:%d"; - private static final long serialVersionUID = -7691068719112973697L; - private final String protocol; - private final String host; - private final int port; - private final String channel; - - private AeronSocketAddress(String protocol, String host, int port) { - this.protocol = protocol; - this.host = host; - this.port = port; - this.channel = String.format(FORMAT, protocol, host, port); - } - - public static AeronSocketAddress create(String protocol, String host, int port) { - return new AeronSocketAddress(protocol, host, port); - } - - public String getProtocol() { - return protocol; - } - - public String getHost() { - return host; - } - - public int getPort() { - return port; - } - - public String getChannel() { - return channel; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - AeronSocketAddress that = (AeronSocketAddress) o; - - return channel != null ? channel.equals(that.channel) : that.channel == null; - } - - @Override - public int hashCode() { - return channel != null ? channel.hashCode() : 0; - } - - @Override - public String toString() { - return "AeronSocketAddress{" - + "protocol='" - + protocol - + '\'' - + ", host='" - + host - + '\'' - + ", port=" - + port - + ", channel='" - + channel - + '\'' - + '}'; - } -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/ReactiveStreamsRemote.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/ReactiveStreamsRemote.java deleted file mode 100644 index 2d030c8c5..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/ReactiveStreamsRemote.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.aeron.internal.reactivestreams; - -import io.rsocket.Closeable; -import java.net.SocketAddress; -import java.util.concurrent.TimeUnit; -import java.util.function.Consumer; -import java.util.function.Function; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -/** Interfaces to define a ReactiveStream over a remote channel */ -public interface ReactiveStreamsRemote { - interface Channel { - Mono send(Flux in); - - default Mono send(T t) { - return send(Flux.just(t)); - } - - Flux receive(); - - boolean isActive(); - } - - interface ClientChannelConnector> - extends Function> {} - - interface ClientChannelConfig {} - - interface ChannelConsumer> extends Consumer {} - - abstract class ChannelServer> { - protected final C channelConsumer; - - public ChannelServer(C channelConsumer) { - this.channelConsumer = channelConsumer; - } - - public abstract StartedServer start(); - } - - interface StartedServer extends Closeable { - /** - * Address for this server. - * - * @return Address for this server. - */ - SocketAddress getServerAddress(); - - /** - * Port for this server. - * - * @return Port for this server. - */ - int getServerPort(); - - /** - * Blocks till this server shutsdown. - * - *

This does not shutdown the server. - */ - void awaitShutdown(); - - /** - * Blocks till this server shutsdown till the passed duration. - * - *

This does not shutdown the server. - * - * @param duration the number of durationUnit to wait - * @param durationUnit the unit e.g. seconds - */ - void awaitShutdown(long duration, TimeUnit durationUnit); - - /** Initiates the shutdown of this server. */ - void shutdown(); - } -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/AckConnectDecoder.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/AckConnectDecoder.java deleted file mode 100644 index 26c3b6651..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/AckConnectDecoder.java +++ /dev/null @@ -1,209 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* Generated SBE (Simple Binary Encoding) message codec */ -package io.rsocket.aeron.internal.reactivestreams.messages; - -import org.agrona.DirectBuffer; - -@javax.annotation.Generated( - value = {"io.rsocket.aeron.internal.reactivestreams.messages.AckConnectDecoder"} -) -@SuppressWarnings("all") -public class AckConnectDecoder { - public static final int BLOCK_LENGTH = 12; - public static final int TEMPLATE_ID = 2; - public static final int SCHEMA_ID = 1; - public static final int SCHEMA_VERSION = 0; - - private final AckConnectDecoder parentMessage = this; - private DirectBuffer buffer; - protected int offset; - protected int limit; - protected int actingBlockLength; - protected int actingVersion; - - public int sbeBlockLength() { - return BLOCK_LENGTH; - } - - public int sbeTemplateId() { - return TEMPLATE_ID; - } - - public int sbeSchemaId() { - return SCHEMA_ID; - } - - public int sbeSchemaVersion() { - return SCHEMA_VERSION; - } - - public String sbeSemanticType() { - return ""; - } - - public int offset() { - return offset; - } - - public AckConnectDecoder wrap( - final DirectBuffer buffer, - final int offset, - final int actingBlockLength, - final int actingVersion) { - this.buffer = buffer; - this.offset = offset; - this.actingBlockLength = actingBlockLength; - this.actingVersion = actingVersion; - limit(offset + actingBlockLength); - - return this; - } - - public int encodedLength() { - return limit - offset; - } - - public int limit() { - return limit; - } - - public void limit(final int limit) { - this.limit = limit; - } - - public static int channelIdId() { - return 1; - } - - public static String channelIdMetaAttribute(final MetaAttribute metaAttribute) { - switch (metaAttribute) { - case EPOCH: - return "unix"; - case TIME_UNIT: - return "nanosecond"; - case SEMANTIC_TYPE: - return ""; - } - - return ""; - } - - public static long channelIdNullValue() { - return -9223372036854775808L; - } - - public static long channelIdMinValue() { - return -9223372036854775807L; - } - - public static long channelIdMaxValue() { - return 9223372036854775807L; - } - - public long channelId() { - return buffer.getLong(offset + 0, java.nio.ByteOrder.LITTLE_ENDIAN); - } - - public static int serverSessionIdId() { - return 2; - } - - public static String serverSessionIdMetaAttribute(final MetaAttribute metaAttribute) { - switch (metaAttribute) { - case EPOCH: - return "unix"; - case TIME_UNIT: - return "nanosecond"; - case SEMANTIC_TYPE: - return ""; - } - - return ""; - } - - public static int serverSessionIdNullValue() { - return -2147483648; - } - - public static int serverSessionIdMinValue() { - return -2147483647; - } - - public static int serverSessionIdMaxValue() { - return 2147483647; - } - - public int serverSessionId() { - return buffer.getInt(offset + 8, java.nio.ByteOrder.LITTLE_ENDIAN); - } - - public String toString() { - return appendTo(new StringBuilder(100)).toString(); - } - - public StringBuilder appendTo(final StringBuilder builder) { - final int originalLimit = limit(); - limit(offset + actingBlockLength); - builder.append("[AckConnect](sbeTemplateId="); - builder.append(TEMPLATE_ID); - builder.append("|sbeSchemaId="); - builder.append(SCHEMA_ID); - builder.append("|sbeSchemaVersion="); - if (actingVersion != SCHEMA_VERSION) { - builder.append(actingVersion); - builder.append('/'); - } - builder.append(SCHEMA_VERSION); - builder.append("|sbeBlockLength="); - if (actingBlockLength != BLOCK_LENGTH) { - builder.append(actingBlockLength); - builder.append('/'); - } - builder.append(BLOCK_LENGTH); - builder.append("):"); - // Token{signal=BEGIN_FIELD, name='channelId', description='The AeronChannel id', id=1, - // version=0, encodedLength=0, offset=0, componentTokenCount=3, - // encoding=Encoding{presence=REQUIRED, primitiveType=null, byteOrder=LITTLE_ENDIAN, - // minValue=null, maxValue=null, nullValue=null, constValue=null, characterEncoding='null', - // epoch='unix', timeUnit=nanosecond, semanticType='null'}} - // Token{signal=ENCODING, name='int64', description='The AeronChannel id', id=-1, version=0, - // encodedLength=8, offset=0, componentTokenCount=1, encoding=Encoding{presence=REQUIRED, - // primitiveType=INT64, byteOrder=LITTLE_ENDIAN, minValue=null, maxValue=null, nullValue=null, - // constValue=null, characterEncoding='null', epoch='unix', timeUnit=nanosecond, - // semanticType='null'}} - builder.append("channelId="); - builder.append(channelId()); - builder.append('|'); - // Token{signal=BEGIN_FIELD, name='serverSessionId', description='The session id for the server - // publication', id=2, version=0, encodedLength=0, offset=8, componentTokenCount=3, - // encoding=Encoding{presence=REQUIRED, primitiveType=null, byteOrder=LITTLE_ENDIAN, - // minValue=null, maxValue=null, nullValue=null, constValue=null, characterEncoding='null', - // epoch='unix', timeUnit=nanosecond, semanticType='null'}} - // Token{signal=ENCODING, name='int32', description='The session id for the server publication', - // id=-1, version=0, encodedLength=4, offset=8, componentTokenCount=1, - // encoding=Encoding{presence=REQUIRED, primitiveType=INT32, byteOrder=LITTLE_ENDIAN, - // minValue=null, maxValue=null, nullValue=null, constValue=null, characterEncoding='null', - // epoch='unix', timeUnit=nanosecond, semanticType='null'}} - builder.append("serverSessionId="); - builder.append(serverSessionId()); - - limit(originalLimit); - - return builder; - } -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/AckConnectEncoder.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/AckConnectEncoder.java deleted file mode 100644 index 613ce03d4..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/AckConnectEncoder.java +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* Generated SBE (Simple Binary Encoding) message codec */ -package io.rsocket.aeron.internal.reactivestreams.messages; - -import org.agrona.MutableDirectBuffer; - -@javax.annotation.Generated( - value = {"io.rsocket.aeron.internal.reactivestreams.messages.AckConnectEncoder"} -) -@SuppressWarnings("all") -public class AckConnectEncoder { - public static final int BLOCK_LENGTH = 12; - public static final int TEMPLATE_ID = 2; - public static final int SCHEMA_ID = 1; - public static final int SCHEMA_VERSION = 0; - - private final AckConnectEncoder parentMessage = this; - private MutableDirectBuffer buffer; - protected int offset; - protected int limit; - protected int actingBlockLength; - protected int actingVersion; - - public int sbeBlockLength() { - return BLOCK_LENGTH; - } - - public int sbeTemplateId() { - return TEMPLATE_ID; - } - - public int sbeSchemaId() { - return SCHEMA_ID; - } - - public int sbeSchemaVersion() { - return SCHEMA_VERSION; - } - - public String sbeSemanticType() { - return ""; - } - - public int offset() { - return offset; - } - - public AckConnectEncoder wrap(final MutableDirectBuffer buffer, final int offset) { - this.buffer = buffer; - this.offset = offset; - limit(offset + BLOCK_LENGTH); - - return this; - } - - public int encodedLength() { - return limit - offset; - } - - public int limit() { - return limit; - } - - public void limit(final int limit) { - this.limit = limit; - } - - public static long channelIdNullValue() { - return -9223372036854775808L; - } - - public static long channelIdMinValue() { - return -9223372036854775807L; - } - - public static long channelIdMaxValue() { - return 9223372036854775807L; - } - - public AckConnectEncoder channelId(final long value) { - buffer.putLong(offset + 0, value, java.nio.ByteOrder.LITTLE_ENDIAN); - return this; - } - - public static int serverSessionIdNullValue() { - return -2147483648; - } - - public static int serverSessionIdMinValue() { - return -2147483647; - } - - public static int serverSessionIdMaxValue() { - return 2147483647; - } - - public AckConnectEncoder serverSessionId(final int value) { - buffer.putInt(offset + 8, value, java.nio.ByteOrder.LITTLE_ENDIAN); - return this; - } - - public String toString() { - return appendTo(new StringBuilder(100)).toString(); - } - - public StringBuilder appendTo(final StringBuilder builder) { - AckConnectDecoder writer = new AckConnectDecoder(); - writer.wrap(buffer, offset, BLOCK_LENGTH, SCHEMA_VERSION); - - return writer.appendTo(builder); - } -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/ConnectDecoder.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/ConnectDecoder.java deleted file mode 100644 index 79927a8f1..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/ConnectDecoder.java +++ /dev/null @@ -1,548 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* Generated SBE (Simple Binary Encoding) message codec */ -package io.rsocket.aeron.internal.reactivestreams.messages; - -import org.agrona.DirectBuffer; -import org.agrona.MutableDirectBuffer; - -@javax.annotation.Generated( - value = {"io.rsocket.aeron.internal.reactivestreams.messages.ConnectDecoder"} -) -@SuppressWarnings("all") -public class ConnectDecoder { - public static final int BLOCK_LENGTH = 20; - public static final int TEMPLATE_ID = 1; - public static final int SCHEMA_ID = 1; - public static final int SCHEMA_VERSION = 0; - - private final ConnectDecoder parentMessage = this; - private DirectBuffer buffer; - protected int offset; - protected int limit; - protected int actingBlockLength; - protected int actingVersion; - - public int sbeBlockLength() { - return BLOCK_LENGTH; - } - - public int sbeTemplateId() { - return TEMPLATE_ID; - } - - public int sbeSchemaId() { - return SCHEMA_ID; - } - - public int sbeSchemaVersion() { - return SCHEMA_VERSION; - } - - public String sbeSemanticType() { - return ""; - } - - public int offset() { - return offset; - } - - public ConnectDecoder wrap( - final DirectBuffer buffer, - final int offset, - final int actingBlockLength, - final int actingVersion) { - this.buffer = buffer; - this.offset = offset; - this.actingBlockLength = actingBlockLength; - this.actingVersion = actingVersion; - limit(offset + actingBlockLength); - - return this; - } - - public int encodedLength() { - return limit - offset; - } - - public int limit() { - return limit; - } - - public void limit(final int limit) { - this.limit = limit; - } - - public static int channelIdId() { - return 1; - } - - public static String channelIdMetaAttribute(final MetaAttribute metaAttribute) { - switch (metaAttribute) { - case EPOCH: - return "unix"; - case TIME_UNIT: - return "nanosecond"; - case SEMANTIC_TYPE: - return ""; - } - - return ""; - } - - public static long channelIdNullValue() { - return -9223372036854775808L; - } - - public static long channelIdMinValue() { - return -9223372036854775807L; - } - - public static long channelIdMaxValue() { - return 9223372036854775807L; - } - - public long channelId() { - return buffer.getLong(offset + 0, java.nio.ByteOrder.LITTLE_ENDIAN); - } - - public static int sendingStreamIdId() { - return 2; - } - - public static String sendingStreamIdMetaAttribute(final MetaAttribute metaAttribute) { - switch (metaAttribute) { - case EPOCH: - return "unix"; - case TIME_UNIT: - return "nanosecond"; - case SEMANTIC_TYPE: - return ""; - } - - return ""; - } - - public static int sendingStreamIdNullValue() { - return -2147483648; - } - - public static int sendingStreamIdMinValue() { - return -2147483647; - } - - public static int sendingStreamIdMaxValue() { - return 2147483647; - } - - public int sendingStreamId() { - return buffer.getInt(offset + 8, java.nio.ByteOrder.LITTLE_ENDIAN); - } - - public static int receivingStreamIdId() { - return 3; - } - - public static String receivingStreamIdMetaAttribute(final MetaAttribute metaAttribute) { - switch (metaAttribute) { - case EPOCH: - return "unix"; - case TIME_UNIT: - return "nanosecond"; - case SEMANTIC_TYPE: - return ""; - } - - return ""; - } - - public static int receivingStreamIdNullValue() { - return -2147483648; - } - - public static int receivingStreamIdMinValue() { - return -2147483647; - } - - public static int receivingStreamIdMaxValue() { - return 2147483647; - } - - public int receivingStreamId() { - return buffer.getInt(offset + 12, java.nio.ByteOrder.LITTLE_ENDIAN); - } - - public static int clientSessionIdId() { - return 4; - } - - public static String clientSessionIdMetaAttribute(final MetaAttribute metaAttribute) { - switch (metaAttribute) { - case EPOCH: - return "unix"; - case TIME_UNIT: - return "nanosecond"; - case SEMANTIC_TYPE: - return ""; - } - - return ""; - } - - public static int clientSessionIdNullValue() { - return -2147483648; - } - - public static int clientSessionIdMinValue() { - return -2147483647; - } - - public static int clientSessionIdMaxValue() { - return 2147483647; - } - - public int clientSessionId() { - return buffer.getInt(offset + 16, java.nio.ByteOrder.LITTLE_ENDIAN); - } - - public static int sendingChannelId() { - return 5; - } - - public static String sendingChannelCharacterEncoding() { - return "UTF-8"; - } - - public static String sendingChannelMetaAttribute(final MetaAttribute metaAttribute) { - switch (metaAttribute) { - case EPOCH: - return "unix"; - case TIME_UNIT: - return "nanosecond"; - case SEMANTIC_TYPE: - return ""; - } - - return ""; - } - - public static int sendingChannelHeaderLength() { - return 4; - } - - public int sendingChannelLength() { - final int limit = parentMessage.limit(); - return (int) (buffer.getInt(limit, java.nio.ByteOrder.LITTLE_ENDIAN) & 0xFFFF_FFFFL); - } - - public int getSendingChannel( - final MutableDirectBuffer dst, final int dstOffset, final int length) { - final int headerLength = 4; - final int limit = parentMessage.limit(); - final int dataLength = - (int) (buffer.getInt(limit, java.nio.ByteOrder.LITTLE_ENDIAN) & 0xFFFF_FFFFL); - final int bytesCopied = Math.min(length, dataLength); - parentMessage.limit(limit + headerLength + dataLength); - buffer.getBytes(limit + headerLength, dst, dstOffset, bytesCopied); - - return bytesCopied; - } - - public int getSendingChannel(final byte[] dst, final int dstOffset, final int length) { - final int headerLength = 4; - final int limit = parentMessage.limit(); - final int dataLength = - (int) (buffer.getInt(limit, java.nio.ByteOrder.LITTLE_ENDIAN) & 0xFFFF_FFFFL); - final int bytesCopied = Math.min(length, dataLength); - parentMessage.limit(limit + headerLength + dataLength); - buffer.getBytes(limit + headerLength, dst, dstOffset, bytesCopied); - - return bytesCopied; - } - - public String sendingChannel() { - final int headerLength = 4; - final int limit = parentMessage.limit(); - final int dataLength = - (int) (buffer.getInt(limit, java.nio.ByteOrder.LITTLE_ENDIAN) & 0xFFFF_FFFFL); - parentMessage.limit(limit + headerLength + dataLength); - final byte[] tmp = new byte[dataLength]; - buffer.getBytes(limit + headerLength, tmp, 0, dataLength); - - final String value; - try { - value = new String(tmp, "UTF-8"); - } catch (final java.io.UnsupportedEncodingException ex) { - throw new RuntimeException(ex); - } - - return value; - } - - public static int receivingChannelId() { - return 6; - } - - public static String receivingChannelCharacterEncoding() { - return "UTF-8"; - } - - public static String receivingChannelMetaAttribute(final MetaAttribute metaAttribute) { - switch (metaAttribute) { - case EPOCH: - return "unix"; - case TIME_UNIT: - return "nanosecond"; - case SEMANTIC_TYPE: - return ""; - } - - return ""; - } - - public static int receivingChannelHeaderLength() { - return 4; - } - - public int receivingChannelLength() { - final int limit = parentMessage.limit(); - return (int) (buffer.getInt(limit, java.nio.ByteOrder.LITTLE_ENDIAN) & 0xFFFF_FFFFL); - } - - public int getReceivingChannel( - final MutableDirectBuffer dst, final int dstOffset, final int length) { - final int headerLength = 4; - final int limit = parentMessage.limit(); - final int dataLength = - (int) (buffer.getInt(limit, java.nio.ByteOrder.LITTLE_ENDIAN) & 0xFFFF_FFFFL); - final int bytesCopied = Math.min(length, dataLength); - parentMessage.limit(limit + headerLength + dataLength); - buffer.getBytes(limit + headerLength, dst, dstOffset, bytesCopied); - - return bytesCopied; - } - - public int getReceivingChannel(final byte[] dst, final int dstOffset, final int length) { - final int headerLength = 4; - final int limit = parentMessage.limit(); - final int dataLength = - (int) (buffer.getInt(limit, java.nio.ByteOrder.LITTLE_ENDIAN) & 0xFFFF_FFFFL); - final int bytesCopied = Math.min(length, dataLength); - parentMessage.limit(limit + headerLength + dataLength); - buffer.getBytes(limit + headerLength, dst, dstOffset, bytesCopied); - - return bytesCopied; - } - - public String receivingChannel() { - final int headerLength = 4; - final int limit = parentMessage.limit(); - final int dataLength = - (int) (buffer.getInt(limit, java.nio.ByteOrder.LITTLE_ENDIAN) & 0xFFFF_FFFFL); - parentMessage.limit(limit + headerLength + dataLength); - final byte[] tmp = new byte[dataLength]; - buffer.getBytes(limit + headerLength, tmp, 0, dataLength); - - final String value; - try { - value = new String(tmp, "UTF-8"); - } catch (final java.io.UnsupportedEncodingException ex) { - throw new RuntimeException(ex); - } - - return value; - } - - public static int clientManagementChannelId() { - return 6; - } - - public static String clientManagementChannelCharacterEncoding() { - return "UTF-8"; - } - - public static String clientManagementChannelMetaAttribute(final MetaAttribute metaAttribute) { - switch (metaAttribute) { - case EPOCH: - return "unix"; - case TIME_UNIT: - return "nanosecond"; - case SEMANTIC_TYPE: - return ""; - } - - return ""; - } - - public static int clientManagementChannelHeaderLength() { - return 4; - } - - public int clientManagementChannelLength() { - final int limit = parentMessage.limit(); - return (int) (buffer.getInt(limit, java.nio.ByteOrder.LITTLE_ENDIAN) & 0xFFFF_FFFFL); - } - - public int getClientManagementChannel( - final MutableDirectBuffer dst, final int dstOffset, final int length) { - final int headerLength = 4; - final int limit = parentMessage.limit(); - final int dataLength = - (int) (buffer.getInt(limit, java.nio.ByteOrder.LITTLE_ENDIAN) & 0xFFFF_FFFFL); - final int bytesCopied = Math.min(length, dataLength); - parentMessage.limit(limit + headerLength + dataLength); - buffer.getBytes(limit + headerLength, dst, dstOffset, bytesCopied); - - return bytesCopied; - } - - public int getClientManagementChannel(final byte[] dst, final int dstOffset, final int length) { - final int headerLength = 4; - final int limit = parentMessage.limit(); - final int dataLength = - (int) (buffer.getInt(limit, java.nio.ByteOrder.LITTLE_ENDIAN) & 0xFFFF_FFFFL); - final int bytesCopied = Math.min(length, dataLength); - parentMessage.limit(limit + headerLength + dataLength); - buffer.getBytes(limit + headerLength, dst, dstOffset, bytesCopied); - - return bytesCopied; - } - - public String clientManagementChannel() { - final int headerLength = 4; - final int limit = parentMessage.limit(); - final int dataLength = - (int) (buffer.getInt(limit, java.nio.ByteOrder.LITTLE_ENDIAN) & 0xFFFF_FFFFL); - parentMessage.limit(limit + headerLength + dataLength); - final byte[] tmp = new byte[dataLength]; - buffer.getBytes(limit + headerLength, tmp, 0, dataLength); - - final String value; - try { - value = new String(tmp, "UTF-8"); - } catch (final java.io.UnsupportedEncodingException ex) { - throw new RuntimeException(ex); - } - - return value; - } - - public String toString() { - return appendTo(new StringBuilder(100)).toString(); - } - - public StringBuilder appendTo(final StringBuilder builder) { - final int originalLimit = limit(); - limit(offset + actingBlockLength); - builder.append("[Connect](sbeTemplateId="); - builder.append(TEMPLATE_ID); - builder.append("|sbeSchemaId="); - builder.append(SCHEMA_ID); - builder.append("|sbeSchemaVersion="); - if (actingVersion != SCHEMA_VERSION) { - builder.append(actingVersion); - builder.append('/'); - } - builder.append(SCHEMA_VERSION); - builder.append("|sbeBlockLength="); - if (actingBlockLength != BLOCK_LENGTH) { - builder.append(actingBlockLength); - builder.append('/'); - } - builder.append(BLOCK_LENGTH); - builder.append("):"); - // Token{signal=BEGIN_FIELD, name='channelId', description='The AeronChannel id', id=1, - // version=0, encodedLength=0, offset=0, componentTokenCount=3, - // encoding=Encoding{presence=REQUIRED, primitiveType=null, byteOrder=LITTLE_ENDIAN, - // minValue=null, maxValue=null, nullValue=null, constValue=null, characterEncoding='null', - // epoch='unix', timeUnit=nanosecond, semanticType='null'}} - // Token{signal=ENCODING, name='int64', description='The AeronChannel id', id=-1, version=0, - // encodedLength=8, offset=0, componentTokenCount=1, encoding=Encoding{presence=REQUIRED, - // primitiveType=INT64, byteOrder=LITTLE_ENDIAN, minValue=null, maxValue=null, nullValue=null, - // constValue=null, characterEncoding='null', epoch='unix', timeUnit=nanosecond, - // semanticType='null'}} - builder.append("channelId="); - builder.append(channelId()); - builder.append('|'); - // Token{signal=BEGIN_FIELD, name='sendingStreamId', description='The stream id the connecting - // client will send traffic on', id=2, version=0, encodedLength=0, offset=8, - // componentTokenCount=3, encoding=Encoding{presence=REQUIRED, primitiveType=null, - // byteOrder=LITTLE_ENDIAN, minValue=null, maxValue=null, nullValue=null, constValue=null, - // characterEncoding='null', epoch='unix', timeUnit=nanosecond, semanticType='null'}} - // Token{signal=ENCODING, name='int32', description='The stream id the connecting client will - // send traffic on', id=-1, version=0, encodedLength=4, offset=8, componentTokenCount=1, - // encoding=Encoding{presence=REQUIRED, primitiveType=INT32, byteOrder=LITTLE_ENDIAN, - // minValue=null, maxValue=null, nullValue=null, constValue=null, characterEncoding='null', - // epoch='unix', timeUnit=nanosecond, semanticType='null'}} - builder.append("sendingStreamId="); - builder.append(sendingStreamId()); - builder.append('|'); - // Token{signal=BEGIN_FIELD, name='receivingStreamId', description='The stream id the connecting - // client will receive data on', id=3, version=0, encodedLength=0, offset=12, - // componentTokenCount=3, encoding=Encoding{presence=REQUIRED, primitiveType=null, - // byteOrder=LITTLE_ENDIAN, minValue=null, maxValue=null, nullValue=null, constValue=null, - // characterEncoding='null', epoch='unix', timeUnit=nanosecond, semanticType='null'}} - // Token{signal=ENCODING, name='int32', description='The stream id the connecting client will - // receive data on', id=-1, version=0, encodedLength=4, offset=12, componentTokenCount=1, - // encoding=Encoding{presence=REQUIRED, primitiveType=INT32, byteOrder=LITTLE_ENDIAN, - // minValue=null, maxValue=null, nullValue=null, constValue=null, characterEncoding='null', - // epoch='unix', timeUnit=nanosecond, semanticType='null'}} - builder.append("receivingStreamId="); - builder.append(receivingStreamId()); - builder.append('|'); - // Token{signal=BEGIN_FIELD, name='clientSessionId', description='The session id for the client - // publication', id=4, version=0, encodedLength=0, offset=16, componentTokenCount=3, - // encoding=Encoding{presence=REQUIRED, primitiveType=null, byteOrder=LITTLE_ENDIAN, - // minValue=null, maxValue=null, nullValue=null, constValue=null, characterEncoding='null', - // epoch='unix', timeUnit=nanosecond, semanticType='null'}} - // Token{signal=ENCODING, name='int32', description='The session id for the client publication', - // id=-1, version=0, encodedLength=4, offset=16, componentTokenCount=1, - // encoding=Encoding{presence=REQUIRED, primitiveType=INT32, byteOrder=LITTLE_ENDIAN, - // minValue=null, maxValue=null, nullValue=null, constValue=null, characterEncoding='null', - // epoch='unix', timeUnit=nanosecond, semanticType='null'}} - builder.append("clientSessionId="); - builder.append(clientSessionId()); - builder.append('|'); - // Token{signal=BEGIN_VAR_DATA, name='sendingChannel', description='The Aeron channel the client - // will send data on', id=5, version=0, encodedLength=0, offset=20, componentTokenCount=6, - // encoding=Encoding{presence=REQUIRED, primitiveType=null, byteOrder=LITTLE_ENDIAN, - // minValue=null, maxValue=null, nullValue=null, constValue=null, characterEncoding='null', - // epoch='unix', timeUnit=nanosecond, semanticType='null'}} - builder.append("sendingChannel="); - builder.append(sendingChannel()); - builder.append('|'); - // Token{signal=BEGIN_VAR_DATA, name='receivingChannel', description='The Aeron channel the - // client will receive data on', id=6, version=0, encodedLength=0, offset=-1, - // componentTokenCount=6, encoding=Encoding{presence=REQUIRED, primitiveType=null, - // byteOrder=LITTLE_ENDIAN, minValue=null, maxValue=null, nullValue=null, constValue=null, - // characterEncoding='null', epoch='unix', timeUnit=nanosecond, semanticType='null'}} - builder.append("receivingChannel="); - builder.append(receivingChannel()); - builder.append('|'); - // Token{signal=BEGIN_VAR_DATA, name='clientManagementChannel', description='The channel the - // client listens for management data on', id=6, version=0, encodedLength=0, offset=-1, - // componentTokenCount=6, encoding=Encoding{presence=REQUIRED, primitiveType=null, - // byteOrder=LITTLE_ENDIAN, minValue=null, maxValue=null, nullValue=null, constValue=null, - // characterEncoding='null', epoch='unix', timeUnit=nanosecond, semanticType='null'}} - builder.append("clientManagementChannel="); - builder.append(clientManagementChannel()); - - limit(originalLimit); - - return builder; - } -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/ConnectEncoder.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/ConnectEncoder.java deleted file mode 100644 index f377f0197..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/ConnectEncoder.java +++ /dev/null @@ -1,392 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* Generated SBE (Simple Binary Encoding) message codec */ -package io.rsocket.aeron.internal.reactivestreams.messages; - -import org.agrona.DirectBuffer; -import org.agrona.MutableDirectBuffer; - -@javax.annotation.Generated( - value = {"io.rsocket.aeron.internal.reactivestreams.messages.ConnectEncoder"} -) -@SuppressWarnings("all") -public class ConnectEncoder { - public static final int BLOCK_LENGTH = 20; - public static final int TEMPLATE_ID = 1; - public static final int SCHEMA_ID = 1; - public static final int SCHEMA_VERSION = 0; - - private final ConnectEncoder parentMessage = this; - private MutableDirectBuffer buffer; - protected int offset; - protected int limit; - protected int actingBlockLength; - protected int actingVersion; - - public int sbeBlockLength() { - return BLOCK_LENGTH; - } - - public int sbeTemplateId() { - return TEMPLATE_ID; - } - - public int sbeSchemaId() { - return SCHEMA_ID; - } - - public int sbeSchemaVersion() { - return SCHEMA_VERSION; - } - - public String sbeSemanticType() { - return ""; - } - - public int offset() { - return offset; - } - - public ConnectEncoder wrap(final MutableDirectBuffer buffer, final int offset) { - this.buffer = buffer; - this.offset = offset; - limit(offset + BLOCK_LENGTH); - - return this; - } - - public int encodedLength() { - return limit - offset; - } - - public int limit() { - return limit; - } - - public void limit(final int limit) { - this.limit = limit; - } - - public static long channelIdNullValue() { - return -9223372036854775808L; - } - - public static long channelIdMinValue() { - return -9223372036854775807L; - } - - public static long channelIdMaxValue() { - return 9223372036854775807L; - } - - public ConnectEncoder channelId(final long value) { - buffer.putLong(offset + 0, value, java.nio.ByteOrder.LITTLE_ENDIAN); - return this; - } - - public static int sendingStreamIdNullValue() { - return -2147483648; - } - - public static int sendingStreamIdMinValue() { - return -2147483647; - } - - public static int sendingStreamIdMaxValue() { - return 2147483647; - } - - public ConnectEncoder sendingStreamId(final int value) { - buffer.putInt(offset + 8, value, java.nio.ByteOrder.LITTLE_ENDIAN); - return this; - } - - public static int receivingStreamIdNullValue() { - return -2147483648; - } - - public static int receivingStreamIdMinValue() { - return -2147483647; - } - - public static int receivingStreamIdMaxValue() { - return 2147483647; - } - - public ConnectEncoder receivingStreamId(final int value) { - buffer.putInt(offset + 12, value, java.nio.ByteOrder.LITTLE_ENDIAN); - return this; - } - - public static int clientSessionIdNullValue() { - return -2147483648; - } - - public static int clientSessionIdMinValue() { - return -2147483647; - } - - public static int clientSessionIdMaxValue() { - return 2147483647; - } - - public ConnectEncoder clientSessionId(final int value) { - buffer.putInt(offset + 16, value, java.nio.ByteOrder.LITTLE_ENDIAN); - return this; - } - - public static int sendingChannelId() { - return 5; - } - - public static String sendingChannelCharacterEncoding() { - return "UTF-8"; - } - - public static String sendingChannelMetaAttribute(final MetaAttribute metaAttribute) { - switch (metaAttribute) { - case EPOCH: - return "unix"; - case TIME_UNIT: - return "nanosecond"; - case SEMANTIC_TYPE: - return ""; - } - - return ""; - } - - public static int sendingChannelHeaderLength() { - return 4; - } - - public ConnectEncoder putSendingChannel( - final DirectBuffer src, final int srcOffset, final int length) { - if (length > 1073741824) { - throw new IllegalArgumentException("length > max value for type: " + length); - } - - final int headerLength = 4; - final int limit = parentMessage.limit(); - parentMessage.limit(limit + headerLength + length); - buffer.putInt(limit, length, java.nio.ByteOrder.LITTLE_ENDIAN); - buffer.putBytes(limit + headerLength, src, srcOffset, length); - - return this; - } - - public ConnectEncoder putSendingChannel(final byte[] src, final int srcOffset, final int length) { - if (length > 1073741824) { - throw new IllegalArgumentException("length > max value for type: " + length); - } - - final int headerLength = 4; - final int limit = parentMessage.limit(); - parentMessage.limit(limit + headerLength + length); - buffer.putInt(limit, length, java.nio.ByteOrder.LITTLE_ENDIAN); - buffer.putBytes(limit + headerLength, src, srcOffset, length); - - return this; - } - - public ConnectEncoder sendingChannel(final String value) { - final byte[] bytes; - try { - bytes = value.getBytes("UTF-8"); - } catch (final java.io.UnsupportedEncodingException ex) { - throw new RuntimeException(ex); - } - - final int length = bytes.length; - if (length > 1073741824) { - throw new IllegalArgumentException("length > max value for type: " + length); - } - - final int headerLength = 4; - final int limit = parentMessage.limit(); - parentMessage.limit(limit + headerLength + length); - buffer.putInt(limit, length, java.nio.ByteOrder.LITTLE_ENDIAN); - buffer.putBytes(limit + headerLength, bytes, 0, length); - - return this; - } - - public static int receivingChannelId() { - return 6; - } - - public static String receivingChannelCharacterEncoding() { - return "UTF-8"; - } - - public static String receivingChannelMetaAttribute(final MetaAttribute metaAttribute) { - switch (metaAttribute) { - case EPOCH: - return "unix"; - case TIME_UNIT: - return "nanosecond"; - case SEMANTIC_TYPE: - return ""; - } - - return ""; - } - - public static int receivingChannelHeaderLength() { - return 4; - } - - public ConnectEncoder putReceivingChannel( - final DirectBuffer src, final int srcOffset, final int length) { - if (length > 1073741824) { - throw new IllegalArgumentException("length > max value for type: " + length); - } - - final int headerLength = 4; - final int limit = parentMessage.limit(); - parentMessage.limit(limit + headerLength + length); - buffer.putInt(limit, length, java.nio.ByteOrder.LITTLE_ENDIAN); - buffer.putBytes(limit + headerLength, src, srcOffset, length); - - return this; - } - - public ConnectEncoder putReceivingChannel( - final byte[] src, final int srcOffset, final int length) { - if (length > 1073741824) { - throw new IllegalArgumentException("length > max value for type: " + length); - } - - final int headerLength = 4; - final int limit = parentMessage.limit(); - parentMessage.limit(limit + headerLength + length); - buffer.putInt(limit, length, java.nio.ByteOrder.LITTLE_ENDIAN); - buffer.putBytes(limit + headerLength, src, srcOffset, length); - - return this; - } - - public ConnectEncoder receivingChannel(final String value) { - final byte[] bytes; - try { - bytes = value.getBytes("UTF-8"); - } catch (final java.io.UnsupportedEncodingException ex) { - throw new RuntimeException(ex); - } - - final int length = bytes.length; - if (length > 1073741824) { - throw new IllegalArgumentException("length > max value for type: " + length); - } - - final int headerLength = 4; - final int limit = parentMessage.limit(); - parentMessage.limit(limit + headerLength + length); - buffer.putInt(limit, length, java.nio.ByteOrder.LITTLE_ENDIAN); - buffer.putBytes(limit + headerLength, bytes, 0, length); - - return this; - } - - public static int clientManagementChannelId() { - return 6; - } - - public static String clientManagementChannelCharacterEncoding() { - return "UTF-8"; - } - - public static String clientManagementChannelMetaAttribute(final MetaAttribute metaAttribute) { - switch (metaAttribute) { - case EPOCH: - return "unix"; - case TIME_UNIT: - return "nanosecond"; - case SEMANTIC_TYPE: - return ""; - } - - return ""; - } - - public static int clientManagementChannelHeaderLength() { - return 4; - } - - public ConnectEncoder putClientManagementChannel( - final DirectBuffer src, final int srcOffset, final int length) { - if (length > 1073741824) { - throw new IllegalArgumentException("length > max value for type: " + length); - } - - final int headerLength = 4; - final int limit = parentMessage.limit(); - parentMessage.limit(limit + headerLength + length); - buffer.putInt(limit, length, java.nio.ByteOrder.LITTLE_ENDIAN); - buffer.putBytes(limit + headerLength, src, srcOffset, length); - - return this; - } - - public ConnectEncoder putClientManagementChannel( - final byte[] src, final int srcOffset, final int length) { - if (length > 1073741824) { - throw new IllegalArgumentException("length > max value for type: " + length); - } - - final int headerLength = 4; - final int limit = parentMessage.limit(); - parentMessage.limit(limit + headerLength + length); - buffer.putInt(limit, length, java.nio.ByteOrder.LITTLE_ENDIAN); - buffer.putBytes(limit + headerLength, src, srcOffset, length); - - return this; - } - - public ConnectEncoder clientManagementChannel(final String value) { - final byte[] bytes; - try { - bytes = value.getBytes("UTF-8"); - } catch (final java.io.UnsupportedEncodingException ex) { - throw new RuntimeException(ex); - } - - final int length = bytes.length; - if (length > 1073741824) { - throw new IllegalArgumentException("length > max value for type: " + length); - } - - final int headerLength = 4; - final int limit = parentMessage.limit(); - parentMessage.limit(limit + headerLength + length); - buffer.putInt(limit, length, java.nio.ByteOrder.LITTLE_ENDIAN); - buffer.putBytes(limit + headerLength, bytes, 0, length); - - return this; - } - - public String toString() { - return appendTo(new StringBuilder(100)).toString(); - } - - public StringBuilder appendTo(final StringBuilder builder) { - ConnectDecoder writer = new ConnectDecoder(); - writer.wrap(buffer, offset, BLOCK_LENGTH, SCHEMA_VERSION); - - return writer.appendTo(builder); - } -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/MessageHeaderDecoder.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/MessageHeaderDecoder.java deleted file mode 100644 index d3c9cc986..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/MessageHeaderDecoder.java +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* Generated SBE (Simple Binary Encoding) message codec */ -package io.rsocket.aeron.internal.reactivestreams.messages; - -import org.agrona.DirectBuffer; - -@javax.annotation.Generated( - value = {"io.rsocket.aeron.internal.reactivestreams.messages.MessageHeaderDecoder"} -) -@SuppressWarnings("all") -public class MessageHeaderDecoder { - public static final int ENCODED_LENGTH = 8; - private DirectBuffer buffer; - private int offset; - - public MessageHeaderDecoder wrap(final DirectBuffer buffer, final int offset) { - this.buffer = buffer; - this.offset = offset; - - return this; - } - - public int encodedLength() { - return ENCODED_LENGTH; - } - - public static int blockLengthNullValue() { - return 65535; - } - - public static int blockLengthMinValue() { - return 0; - } - - public static int blockLengthMaxValue() { - return 65534; - } - - public int blockLength() { - return (buffer.getShort(offset + 0, java.nio.ByteOrder.LITTLE_ENDIAN) & 0xFFFF); - } - - public static int templateIdNullValue() { - return 65535; - } - - public static int templateIdMinValue() { - return 0; - } - - public static int templateIdMaxValue() { - return 65534; - } - - public int templateId() { - return (buffer.getShort(offset + 2, java.nio.ByteOrder.LITTLE_ENDIAN) & 0xFFFF); - } - - public static int schemaIdNullValue() { - return 65535; - } - - public static int schemaIdMinValue() { - return 0; - } - - public static int schemaIdMaxValue() { - return 65534; - } - - public int schemaId() { - return (buffer.getShort(offset + 4, java.nio.ByteOrder.LITTLE_ENDIAN) & 0xFFFF); - } - - public static int versionNullValue() { - return 65535; - } - - public static int versionMinValue() { - return 0; - } - - public static int versionMaxValue() { - return 65534; - } - - public int version() { - return (buffer.getShort(offset + 6, java.nio.ByteOrder.LITTLE_ENDIAN) & 0xFFFF); - } -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/MessageHeaderEncoder.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/MessageHeaderEncoder.java deleted file mode 100644 index c646a8a73..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/MessageHeaderEncoder.java +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* Generated SBE (Simple Binary Encoding) message codec */ -package io.rsocket.aeron.internal.reactivestreams.messages; - -import org.agrona.MutableDirectBuffer; - -@javax.annotation.Generated( - value = {"io.rsocket.aeron.internal.reactivestreams.messages.MessageHeaderEncoder"} -) -@SuppressWarnings("all") -public class MessageHeaderEncoder { - public static final int ENCODED_LENGTH = 8; - private MutableDirectBuffer buffer; - private int offset; - - public MessageHeaderEncoder wrap(final MutableDirectBuffer buffer, final int offset) { - this.buffer = buffer; - this.offset = offset; - - return this; - } - - public int encodedLength() { - return ENCODED_LENGTH; - } - - public static int blockLengthNullValue() { - return 65535; - } - - public static int blockLengthMinValue() { - return 0; - } - - public static int blockLengthMaxValue() { - return 65534; - } - - public MessageHeaderEncoder blockLength(final int value) { - buffer.putShort(offset + 0, (short) value, java.nio.ByteOrder.LITTLE_ENDIAN); - return this; - } - - public static int templateIdNullValue() { - return 65535; - } - - public static int templateIdMinValue() { - return 0; - } - - public static int templateIdMaxValue() { - return 65534; - } - - public MessageHeaderEncoder templateId(final int value) { - buffer.putShort(offset + 2, (short) value, java.nio.ByteOrder.LITTLE_ENDIAN); - return this; - } - - public static int schemaIdNullValue() { - return 65535; - } - - public static int schemaIdMinValue() { - return 0; - } - - public static int schemaIdMaxValue() { - return 65534; - } - - public MessageHeaderEncoder schemaId(final int value) { - buffer.putShort(offset + 4, (short) value, java.nio.ByteOrder.LITTLE_ENDIAN); - return this; - } - - public static int versionNullValue() { - return 65535; - } - - public static int versionMinValue() { - return 0; - } - - public static int versionMaxValue() { - return 65534; - } - - public MessageHeaderEncoder version(final int value) { - buffer.putShort(offset + 6, (short) value, java.nio.ByteOrder.LITTLE_ENDIAN); - return this; - } -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/MetaAttribute.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/MetaAttribute.java deleted file mode 100644 index bdf3eb985..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/MetaAttribute.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* Generated SBE (Simple Binary Encoding) message codec */ -package io.rsocket.aeron.internal.reactivestreams.messages; - -@javax.annotation.Generated( - value = {"io.rsocket.aeron.internal.reactivestreams.messages.MetaAttribute"} -) -public enum MetaAttribute { - EPOCH, - TIME_UNIT, - SEMANTIC_TYPE -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/VarDataEncodingDecoder.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/VarDataEncodingDecoder.java deleted file mode 100644 index 7eb0d9328..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/VarDataEncodingDecoder.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* Generated SBE (Simple Binary Encoding) message codec */ -package io.rsocket.aeron.internal.reactivestreams.messages; - -import org.agrona.DirectBuffer; - -@javax.annotation.Generated( - value = {"io.rsocket.aeron.internal.reactivestreams.messages.VarDataEncodingDecoder"} -) -@SuppressWarnings("all") -public class VarDataEncodingDecoder { - public static final int ENCODED_LENGTH = -1; - private DirectBuffer buffer; - private int offset; - - public VarDataEncodingDecoder wrap(final DirectBuffer buffer, final int offset) { - this.buffer = buffer; - this.offset = offset; - - return this; - } - - public int encodedLength() { - return ENCODED_LENGTH; - } - - public static long lengthNullValue() { - return 4294967294L; - } - - public static long lengthMinValue() { - return 0L; - } - - public static long lengthMaxValue() { - return 1073741824L; - } - - public long length() { - return (buffer.getInt(offset + 0, java.nio.ByteOrder.LITTLE_ENDIAN) & 0xFFFF_FFFFL); - } - - public static short varDataNullValue() { - return (short) 255; - } - - public static short varDataMinValue() { - return (short) 0; - } - - public static short varDataMaxValue() { - return (short) 254; - } - - public String toString() { - return appendTo(new StringBuilder(100)).toString(); - } - - public StringBuilder appendTo(final StringBuilder builder) { - builder.append('('); - // Token{signal=ENCODING, name='length', description='The channel the client listens for - // management data on', id=-1, version=0, encodedLength=4, offset=0, componentTokenCount=1, - // encoding=Encoding{presence=REQUIRED, primitiveType=UINT32, byteOrder=LITTLE_ENDIAN, - // minValue=null, maxValue=1073741824, nullValue=null, constValue=null, - // characterEncoding='UTF-8', epoch='unix', timeUnit=nanosecond, semanticType='null'}} - builder.append("length="); - builder.append(length()); - builder.append('|'); - // Token{signal=ENCODING, name='varData', description='The channel the client listens for - // management data on', id=-1, version=0, encodedLength=-1, offset=4, componentTokenCount=1, - // encoding=Encoding{presence=REQUIRED, primitiveType=UINT8, byteOrder=LITTLE_ENDIAN, - // minValue=null, maxValue=null, nullValue=null, constValue=null, characterEncoding='UTF-8', - // epoch='unix', timeUnit=nanosecond, semanticType='null'}} - builder.append(')'); - - return builder; - } -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/VarDataEncodingEncoder.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/VarDataEncodingEncoder.java deleted file mode 100644 index 6c100fff4..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/internal/reactivestreams/messages/VarDataEncodingEncoder.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* Generated SBE (Simple Binary Encoding) message codec */ -package io.rsocket.aeron.internal.reactivestreams.messages; - -import org.agrona.MutableDirectBuffer; - -@javax.annotation.Generated( - value = {"io.rsocket.aeron.internal.reactivestreams.messages.VarDataEncodingEncoder"} -) -@SuppressWarnings("all") -public class VarDataEncodingEncoder { - public static final int ENCODED_LENGTH = -1; - private MutableDirectBuffer buffer; - private int offset; - - public VarDataEncodingEncoder wrap(final MutableDirectBuffer buffer, final int offset) { - this.buffer = buffer; - this.offset = offset; - - return this; - } - - public int encodedLength() { - return ENCODED_LENGTH; - } - - public static long lengthNullValue() { - return 4294967294L; - } - - public static long lengthMinValue() { - return 0L; - } - - public static long lengthMaxValue() { - return 1073741824L; - } - - public VarDataEncodingEncoder length(final long value) { - buffer.putInt(offset + 0, (int) value, java.nio.ByteOrder.LITTLE_ENDIAN); - return this; - } - - public static short varDataNullValue() { - return (short) 255; - } - - public static short varDataMinValue() { - return (short) 0; - } - - public static short varDataMaxValue() { - return (short) 254; - } - - public String toString() { - return appendTo(new StringBuilder(100)).toString(); - } - - public StringBuilder appendTo(final StringBuilder builder) { - VarDataEncodingDecoder writer = new VarDataEncodingDecoder(); - writer.wrap(buffer, offset); - - return writer.appendTo(builder); - } -} diff --git a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/server/AeronServerTransport.java b/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/server/AeronServerTransport.java deleted file mode 100644 index 4deee205d..000000000 --- a/rsocket-transport-aeron/src/main/java/io/rsocket/aeron/server/AeronServerTransport.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.aeron.server; - -import io.rsocket.Closeable; -import io.rsocket.DuplexConnection; -import io.rsocket.aeron.AeronDuplexConnection; -import io.rsocket.aeron.internal.AeronWrapper; -import io.rsocket.aeron.internal.EventLoop; -import io.rsocket.aeron.internal.reactivestreams.AeronChannelServer; -import io.rsocket.aeron.internal.reactivestreams.AeronSocketAddress; -import io.rsocket.transport.ServerTransport; -import reactor.core.publisher.Mono; - -/** */ -public class AeronServerTransport implements ServerTransport { - private final AeronWrapper aeronWrapper; - private final AeronSocketAddress managementSubscriptionSocket; - private final EventLoop eventLoop; - - private AeronChannelServer aeronChannelServer; - - public AeronServerTransport( - AeronWrapper aeronWrapper, - AeronSocketAddress managementSubscriptionSocket, - EventLoop eventLoop) { - this.aeronWrapper = aeronWrapper; - this.managementSubscriptionSocket = managementSubscriptionSocket; - this.eventLoop = eventLoop; - } - - @Override - public Mono start(ConnectionAcceptor acceptor) { - synchronized (this) { - if (aeronChannelServer != null) { - throw new IllegalStateException("server already ready started"); - } - - aeronChannelServer = - AeronChannelServer.create( - aeronChannel -> { - DuplexConnection connection = new AeronDuplexConnection("server", aeronChannel); - acceptor.apply(connection).subscribe(); - }, - aeronWrapper, - managementSubscriptionSocket, - eventLoop); - } - - return Mono.just(aeronChannelServer.start()); - } -} diff --git a/rsocket-transport-aeron/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler b/rsocket-transport-aeron/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler deleted file mode 100644 index bec0f3f46..000000000 --- a/rsocket-transport-aeron/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler +++ /dev/null @@ -1,17 +0,0 @@ -# -# Copyright 2016 Netflix, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -io.rsocket.aeron.AeronUriHandler diff --git a/rsocket-transport-aeron/src/main/resources/aeron-channel-schema.xml b/rsocket-transport-aeron/src/main/resources/aeron-channel-schema.xml deleted file mode 100644 index 2e9818812..000000000 --- a/rsocket-transport-aeron/src/main/resources/aeron-channel-schema.xml +++ /dev/null @@ -1,54 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/AeronClientSetupRule.java b/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/AeronClientSetupRule.java deleted file mode 100644 index 023f4c171..000000000 --- a/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/AeronClientSetupRule.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.aeron; - -import io.rsocket.Closeable; -import io.rsocket.aeron.client.AeronClientTransport; -import io.rsocket.aeron.internal.*; -import io.rsocket.aeron.internal.reactivestreams.AeronClientChannelConnector; -import io.rsocket.aeron.internal.reactivestreams.AeronSocketAddress; -import io.rsocket.aeron.server.AeronServerTransport; -import io.rsocket.test.ClientSetupRule; - -class AeronClientSetupRule extends ClientSetupRule { - - public static final AeronSocketAddress ADDRESS = - AeronSocketAddress.create("aeron:udp", "127.0.0.1", 39790); - - static { - MediaDriverHolder.getInstance(); - AeronWrapper aeronWrapper = new DefaultAeronWrapper(); - - EventLoop serverEventLoop = new SingleThreadedEventLoop("server"); - server = new AeronServerTransport(aeronWrapper, ADDRESS, serverEventLoop); - - // Create Client Connector - EventLoop clientEventLoop = new SingleThreadedEventLoop("client"); - - AeronClientChannelConnector.AeronClientConfig config = - AeronClientChannelConnector.AeronClientConfig.create( - ADDRESS, - ADDRESS, - Constants.CLIENT_STREAM_ID, - Constants.SERVER_STREAM_ID, - clientEventLoop); - - AeronClientChannelConnector connector = - AeronClientChannelConnector.create(aeronWrapper, ADDRESS, clientEventLoop); - - client = new AeronClientTransport(connector, config); - } - - private static final AeronServerTransport server; - private static final AeronClientTransport client; - - AeronClientSetupRule() { - super(() -> ADDRESS, (address, server) -> client, address -> server); - } -} diff --git a/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/AeronPing.java b/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/AeronPing.java deleted file mode 100644 index 64231f35d..000000000 --- a/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/AeronPing.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.aeron; - -import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; -import io.rsocket.aeron.client.AeronClientTransport; -import io.rsocket.aeron.internal.*; -import io.rsocket.aeron.internal.reactivestreams.AeronClientChannelConnector; -import io.rsocket.aeron.internal.reactivestreams.AeronSocketAddress; -import io.rsocket.test.PingClient; -import java.time.Duration; -import org.HdrHistogram.Recorder; -import reactor.core.publisher.Mono; - -public final class AeronPing { - - public static void main(String... args) { - // Create Client Connector - AeronWrapper aeronWrapper = new DefaultAeronWrapper(); - - AeronSocketAddress clientManagementSocketAddress = - AeronSocketAddress.create("aeron:udp", "127.0.0.1", 39790); - EventLoop clientEventLoop = new SingleThreadedEventLoop("client"); - - AeronSocketAddress receiveAddress = AeronSocketAddress.create("aeron:udp", "127.0.0.1", 39790); - AeronSocketAddress sendAddress = AeronSocketAddress.create("aeron:udp", "127.0.0.1", 39790); - - AeronClientChannelConnector.AeronClientConfig config = - AeronClientChannelConnector.AeronClientConfig.create( - receiveAddress, - sendAddress, - Constants.CLIENT_STREAM_ID, - Constants.SERVER_STREAM_ID, - clientEventLoop); - - AeronClientChannelConnector connector = - AeronClientChannelConnector.create( - aeronWrapper, clientManagementSocketAddress, clientEventLoop); - - AeronClientTransport aeronTransportClient = new AeronClientTransport(connector, config); - - Mono client = RSocketFactory.connect().transport(aeronTransportClient).start(); - PingClient pingClient = new PingClient(client); - Recorder recorder = pingClient.startTracker(Duration.ofSeconds(1)); - final int count = 1_000_000_000; - pingClient - .startPingPong(count, recorder) - .doOnTerminate(() -> System.out.println("Sent " + count + " messages.")) - .blockLast(); - - System.exit(0); - } -} diff --git a/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/AeronPongServer.java b/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/AeronPongServer.java deleted file mode 100644 index f0625235c..000000000 --- a/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/AeronPongServer.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.aeron; - -import io.aeron.driver.MediaDriver; -import io.aeron.driver.ThreadingMode; -import io.rsocket.RSocketFactory; -import io.rsocket.aeron.internal.AeronWrapper; -import io.rsocket.aeron.internal.DefaultAeronWrapper; -import io.rsocket.aeron.internal.EventLoop; -import io.rsocket.aeron.internal.SingleThreadedEventLoop; -import io.rsocket.aeron.internal.reactivestreams.AeronSocketAddress; -import io.rsocket.aeron.server.AeronServerTransport; -import io.rsocket.test.PingHandler; - -public final class AeronPongServer { - static { - final io.aeron.driver.MediaDriver.Context ctx = - new io.aeron.driver.MediaDriver.Context() - .threadingMode(ThreadingMode.SHARED_NETWORK) - .dirDeleteOnStart(true); - MediaDriver.launch(ctx); - } - - public static void main(String... args) { - MediaDriverHolder.getInstance(); - AeronWrapper aeronWrapper = new DefaultAeronWrapper(); - - AeronSocketAddress serverManagementSocketAddress = - AeronSocketAddress.create("aeron:udp", "127.0.0.1", 39790); - EventLoop serverEventLoop = new SingleThreadedEventLoop("server"); - AeronServerTransport server = - new AeronServerTransport(aeronWrapper, serverManagementSocketAddress, serverEventLoop); - - AeronServerTransport transport = - new AeronServerTransport(aeronWrapper, serverManagementSocketAddress, serverEventLoop); - RSocketFactory.receive().acceptor(new PingHandler()).transport(transport).start(); - } -} diff --git a/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/ClientServerTest.java b/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/ClientServerTest.java deleted file mode 100644 index d0383a1b1..000000000 --- a/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/ClientServerTest.java +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.aeron; - -import static org.junit.Assert.assertEquals; - -import io.rsocket.Payload; -import io.rsocket.test.ClientSetupRule; -import io.rsocket.util.PayloadImpl; -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import reactor.core.publisher.Flux; - -@Ignore -public class ClientServerTest { - - @Rule public final ClientSetupRule setup = new AeronClientSetupRule(); - - @Test(timeout = 10000) - public void testFireNForget10() { - long outputCount = - Flux.range(1, 10) - .flatMap(i -> setup.getRSocket().fireAndForget(new PayloadImpl("hello", "metadata"))) - .doOnError(Throwable::printStackTrace) - .count() - .block(); - - assertEquals(0, outputCount); - } - - @Test(timeout = 10000) - public void testPushMetadata10() { - long outputCount = - Flux.range(1, 10) - .flatMap(i -> setup.getRSocket().metadataPush(new PayloadImpl("", "metadata"))) - .doOnError(Throwable::printStackTrace) - .count() - .block(); - - assertEquals(0, outputCount); - } - - @Test(timeout = 5000000) - public void testRequestResponse1() { - long outputCount = - Flux.range(1, 1) - .flatMap( - i -> - setup - .getRSocket() - .requestResponse(new PayloadImpl("hello", "metadata")) - .map(Payload::getDataUtf8)) - .doOnError(Throwable::printStackTrace) - .count() - .block(); - - assertEquals(1, outputCount); - } - - @Test(timeout = 2000) - public void testRequestResponse10() { - long outputCount = - Flux.range(1, 10) - .flatMap( - i -> - setup - .getRSocket() - .requestResponse(new PayloadImpl("hello", "metadata")) - .map(Payload::getDataUtf8)) - .doOnError(Throwable::printStackTrace) - .count() - .block(); - - assertEquals(10, outputCount); - } - - @Test(timeout = 2000) - public void testRequestResponse100() { - long outputCount = - Flux.range(1, 100) - .flatMap( - i -> - setup - .getRSocket() - .requestResponse(new PayloadImpl("hello", "metadata")) - .map(Payload::getDataUtf8)) - .doOnError(Throwable::printStackTrace) - .count() - .block(); - - assertEquals(100, outputCount); - } - - @Test(timeout = 5000) - public void testRequestResponse10_000() { - long outputCount = - Flux.range(1, 10_000) - .flatMap( - i -> - setup - .getRSocket() - .requestResponse(new PayloadImpl("hello", "metadata")) - .map(Payload::getDataUtf8)) - .doOnError(Throwable::printStackTrace) - .count() - .block(); - - assertEquals(10_000, outputCount); - } - - @Test(timeout = 10000) - public void testRequestStream() { - Flux publisher = - setup.getRSocket().requestStream(new PayloadImpl("hello", "metadata")); - - long count = publisher.take(5).count().block(); - - assertEquals(5, count); - } -} diff --git a/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/MediaDriverHolder.java b/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/MediaDriverHolder.java deleted file mode 100644 index 0e16bc865..000000000 --- a/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/MediaDriverHolder.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.aeron; - -import io.aeron.driver.MediaDriver; -import io.aeron.driver.ThreadingMode; -import java.util.concurrent.TimeUnit; -import org.agrona.concurrent.SleepingIdleStrategy; - -public class MediaDriverHolder { - private static final MediaDriverHolder INSTANCE = new MediaDriverHolder(); - - static { - final io.aeron.driver.MediaDriver.Context ctx = - new io.aeron.driver.MediaDriver.Context() - .threadingMode(ThreadingMode.SHARED) - .dirDeleteOnStart(true) - .conductorIdleStrategy(new SleepingIdleStrategy(TimeUnit.MILLISECONDS.toNanos(1))) - .receiverIdleStrategy(new SleepingIdleStrategy(TimeUnit.MILLISECONDS.toNanos(1))) - .senderIdleStrategy(new SleepingIdleStrategy(TimeUnit.MILLISECONDS.toNanos(1))); - - ctx.driverTimeoutMs(TimeUnit.MINUTES.toMillis(10)); - MediaDriver.launch(ctx); - } - - private MediaDriverHolder() {} - - public static MediaDriverHolder getInstance() { - return INSTANCE; - } -} diff --git a/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/internal/reactivestreams/AeronChannelPing.java b/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/internal/reactivestreams/AeronChannelPing.java deleted file mode 100644 index ec7b720c3..000000000 --- a/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/internal/reactivestreams/AeronChannelPing.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.aeron.internal.reactivestreams; - -import io.rsocket.aeron.internal.AeronWrapper; -import io.rsocket.aeron.internal.DefaultAeronWrapper; -import io.rsocket.aeron.internal.SingleThreadedEventLoop; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicLong; -import org.HdrHistogram.Recorder; -import org.agrona.concurrent.UnsafeBuffer; -import reactor.core.publisher.Flux; - -/** */ -public final class AeronChannelPing { - public static void main(String... args) { - int count = 1_000_000_000; - final Recorder histogram = new Recorder(Long.MAX_VALUE, 3); - Executors.newSingleThreadScheduledExecutor() - .scheduleAtFixedRate( - () -> { - System.out.println("---- PING/ PONG HISTO ----"); - histogram - .getIntervalHistogram() - .outputPercentileDistribution(System.out, 5, 1000.0, false); - System.out.println("---- PING/ PONG HISTO ----"); - }, - 1, - 1, - TimeUnit.SECONDS); - - AeronWrapper wrapper = new DefaultAeronWrapper(); - AeronSocketAddress managementSocketAddress = - AeronSocketAddress.create("aeron:udp", "127.0.0.1", 39790); - SingleThreadedEventLoop eventLoop = new SingleThreadedEventLoop("client"); - AeronClientChannelConnector connector = - AeronClientChannelConnector.create(wrapper, managementSocketAddress, eventLoop); - - AeronSocketAddress receiveAddress = AeronSocketAddress.create("aeron:udp", "127.0.0.1", 39790); - AeronSocketAddress sendAddress = AeronSocketAddress.create("aeron:udp", "127.0.0.1", 39790); - - AeronClientChannelConnector.AeronClientConfig config = - AeronClientChannelConnector.AeronClientConfig.create( - receiveAddress, sendAddress, 1, 2, eventLoop); - - AeronChannel channel = connector.apply(config).block(); - - AtomicLong lastUpdate = new AtomicLong(System.nanoTime()); - channel - .receive() - .doOnNext( - b -> { - synchronized (wrapper) { - int anInt = b.getInt(0); - if (anInt % 1_000 == 0) { - long diff = System.nanoTime() - lastUpdate.get(); - histogram.recordValue(diff); - lastUpdate.set(System.nanoTime()); - } - } - }) - .doOnError(Throwable::printStackTrace) - .subscribe(); - - byte[] b = new byte[1024]; - Flux.range(0, count) - .flatMap( - i -> { - UnsafeBuffer buffer = new UnsafeBuffer(b); - buffer.putInt(0, i); - return channel.send(buffer); - }, - 8) - .last(null) - .block(); - } -} diff --git a/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/internal/reactivestreams/AeronChannelPongServer.java b/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/internal/reactivestreams/AeronChannelPongServer.java deleted file mode 100644 index 0ef151519..000000000 --- a/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/internal/reactivestreams/AeronChannelPongServer.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.aeron.internal.reactivestreams; - -import io.rsocket.aeron.MediaDriverHolder; -import io.rsocket.aeron.internal.AeronWrapper; -import io.rsocket.aeron.internal.DefaultAeronWrapper; -import io.rsocket.aeron.internal.SingleThreadedEventLoop; -import org.agrona.DirectBuffer; -import reactor.core.publisher.Flux; - -/** */ -public class AeronChannelPongServer { - public static void main(String... args) { - MediaDriverHolder.getInstance(); - AeronWrapper wrapper = new DefaultAeronWrapper(); - AeronSocketAddress managementSubscription = - AeronSocketAddress.create("aeron:udp", "127.0.0.1", 39790); - SingleThreadedEventLoop eventLoop = new SingleThreadedEventLoop("server"); - - AeronChannelServer.AeronChannelConsumer consumer = - aeronChannel -> { - Flux receive = aeronChannel.receive(); - // .doOnNext(b -> System.out.println("server got => " + b.getInt(0))); - - aeronChannel.send(receive).doOnError(Throwable::printStackTrace).subscribe(); - }; - - AeronChannelServer server = - AeronChannelServer.create(consumer, wrapper, managementSubscription, eventLoop); - AeronChannelServer.AeronChannelStartedServer start = server.start(); - start.awaitShutdown(); - } -} diff --git a/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/internal/reactivestreams/AeronChannelTest.java b/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/internal/reactivestreams/AeronChannelTest.java deleted file mode 100644 index 9c9f7ace1..000000000 --- a/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/internal/reactivestreams/AeronChannelTest.java +++ /dev/null @@ -1,341 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.aeron.internal.reactivestreams; - -import io.aeron.Aeron; -import io.aeron.Publication; -import io.aeron.Subscription; -import io.rsocket.aeron.MediaDriverHolder; -import io.rsocket.aeron.internal.Constants; -import io.rsocket.aeron.internal.EventLoop; -import io.rsocket.aeron.internal.SingleThreadedEventLoop; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ThreadLocalRandom; -import org.agrona.BitUtil; -import org.agrona.LangUtil; -import org.agrona.concurrent.UnsafeBuffer; -import org.junit.Ignore; -import org.junit.Test; -import reactor.core.publisher.Flux; - -/** */ -@Ignore("travis does not like me") -public class AeronChannelTest { - static { - // System.setProperty("aeron.publication.linger.timeout", String.valueOf(50_000_000_000L)); - // System.setProperty("aeron.client.liveness.timeout", String.valueOf(50_000_000_000L)); - MediaDriverHolder.getInstance(); - } - - @Test - @Ignore - public void testPing() { - - int count = 5_000_000; - CountDownLatch countDownLatch = new CountDownLatch(count); - - CountDownLatch sync = new CountDownLatch(2); - Aeron.Context ctx = new Aeron.Context(); - - // ctx.publicationConnectionTimeout(TimeUnit.MINUTES.toNanos(5)); - - ctx.availableImageHandler( - image -> { - System.out.println( - "name image subscription => " - + image.subscription().channel() - + " streamId => " - + image.subscription().streamId() - + " registrationId => " - + image.subscription().registrationId()); - sync.countDown(); - }); - - ctx.unavailableImageHandler( - image -> - System.out.println( - "=== unavailable image name image subscription => " - + image.subscription().channel() - + " streamId => " - + image.subscription().streamId() - + " registrationId => " - + image.subscription().registrationId())); - /*ctx.errorHandler(t -> { - /* StringWriter writer = new StringWriter(); - PrintWriter w = new PrintWriter(writer); - t.printStackTrace(w); - - w.flush();* - - // System.out.println("\nGOT AERON ERROR => \n [" + writer.toString() + "]\n\n"); - });*/ - - ctx.driverTimeoutMs(Integer.MAX_VALUE); - Aeron aeron = Aeron.connect(ctx); - /* - Subscription serverSubscription = aeron.addSubscription("aeron:ipc", Constants.SERVER_STREAM_ID); - Publication serverPublication = aeron.addPublication("aeron:ipc", Constants.CLIENT_STREAM_ID); - - Subscription clientSubscription = aeron.addSubscription("aeron:ipc", Constants.CLIENT_STREAM_ID); - Publication clientPublication = aeron.addPublication("aeron:ipc", Constants.SERVER_STREAM_ID); - */ - - Subscription serverSubscription = - aeron.addSubscription("aeron:udp?endpoint=localhost:39791", Constants.SERVER_STREAM_ID); - System.out.println( - "serverSubscription registration id => " + serverSubscription.registrationId()); - - Publication serverPublication = - aeron.addPublication("aeron:udp?endpoint=localhost:39790", Constants.CLIENT_STREAM_ID); - - Subscription clientSubscription = - aeron.addSubscription("aeron:udp?endpoint=localhost:39790", Constants.CLIENT_STREAM_ID); - - System.out.println( - "clientSubscription registration id => " + clientSubscription.registrationId()); - Publication clientPublication = - aeron.addPublication("aeron:udp?endpoint=localhost:39791", Constants.SERVER_STREAM_ID); - - try { - sync.await(); - } catch (InterruptedException e) { - e.printStackTrace(); - } - - EventLoop serverLoop = new SingleThreadedEventLoop("server"); - - AeronOutPublisher publisher = - new AeronOutPublisher( - "server", clientPublication.sessionId(), serverSubscription, serverLoop); - publisher - .doOnNext(i -> countDownLatch.countDown()) - .doOnError(Throwable::printStackTrace) - .subscribe(); - - AeronInSubscriber aeronInSubscriber = new AeronInSubscriber("client", clientPublication); - - Flux unsafeBufferObservable = - Flux.range(1, count) - // .doOnNext(i -> LockSupport.parkNanos(TimeUnit.MICROSECONDS.toNanos(50))) - // .doOnNext(i -> System.out.println(Thread.currentThread() + " => client sending => " + - // i)) - .map( - i -> { - UnsafeBuffer buffer = new UnsafeBuffer(new byte[BitUtil.SIZE_OF_INT]); - buffer.putInt(0, i); - return buffer; - }) - // .doOnRequest(l -> System.out.println("Client reuqested => " + l)) - .doOnError(Throwable::printStackTrace) - .doOnComplete(() -> System.out.println("Im done")); - - unsafeBufferObservable.subscribe(aeronInSubscriber); - - try { - countDownLatch.await(); - } catch (InterruptedException e) { - LangUtil.rethrowUnchecked(e); - } - System.out.println("HERE!!!!"); - } - - @Test(timeout = 2_000) - public void testPingPong_10() { - pingPong(10); - } - - @Test(timeout = 2_000) - public void testPingPong_100() { - pingPong(100); - } - - @Test(timeout = 5_000) - public void testPingPong_300() { - pingPong(300); - } - - @Test(timeout = 5_000) - public void testPingPong_1_000() { - pingPong(1_000); - } - - @Test(timeout = 15_000) - public void testPingPong_10_000() { - pingPong(10_000); - } - - @Ignore - @Test(timeout = 5_000) - public void testPingPong_100_000() { - pingPong(100_000); - } - - @Ignore - @Test(timeout = 15_000) - public void testPingPong_1_000_000() { - pingPong(1_000_000); - } - - @Test(timeout = 50_000) - @Ignore - public void testPingPong_10_000_000() { - pingPong(10_000_000); - } - - @Test - @Ignore - public void testPingPongAlot() { - pingPong(100_000_000); - } - - private void pingPong(int count) { - - CountDownLatch sync = new CountDownLatch(2); - Aeron.Context ctx = new Aeron.Context(); - ctx.availableImageHandler( - image -> { - System.out.println( - "name image subscription => " - + image.subscription().channel() - + " streamId => " - + image.subscription().streamId() - + " registrationId => " - + image.subscription().registrationId()); - sync.countDown(); - }); - - ctx.unavailableImageHandler( - image -> - System.out.println( - "=== unavailable image name image subscription => " - + image.subscription().channel() - + " streamId => " - + image.subscription().streamId() - + " registrationId => " - + image.subscription().registrationId())); - - /*ctx.errorHandler(t -> { - /* StringWriter writer = new StringWriter(); - PrintWriter w = new PrintWriter(writer); - t.printStackTrace(w); - - w.flush();* - - // System.out.println("\nGOT AERON ERROR => \n [" + writer.toString() + "]\n\n"); - });*/ - - // ctx.driverTimeoutMs(Integer.MAX_VALUE); - Aeron aeron = Aeron.connect(ctx); - - Subscription serverSubscription = - aeron.addSubscription("aeron:ipc", Constants.SERVER_STREAM_ID); - Publication serverPublication = aeron.addPublication("aeron:ipc", Constants.CLIENT_STREAM_ID); - - Subscription clientSubscription = - aeron.addSubscription("aeron:ipc", Constants.CLIENT_STREAM_ID); - Publication clientPublication = aeron.addPublication("aeron:ipc", Constants.SERVER_STREAM_ID); - - /* - Subscription serverSubscription = aeron.addSubscription("udp://localhost:39791", Constants.SERVER_STREAM_ID); - System.out.println("serverSubscription registration id => " + serverSubscription.registrationId()); - - Publication serverPublication = aeron.addPublication("udp://localhost:39790", Constants.CLIENT_STREAM_ID); - - Subscription clientSubscription = aeron.addSubscription("udp://localhost:39790", Constants.CLIENT_STREAM_ID); - - System.out.println("clientSubscription registration id => " + clientSubscription.registrationId()); - Publication clientPublication = aeron.addPublication("udp://localhost:39791", Constants.SERVER_STREAM_ID); - */ - try { - sync.await(); - } catch (InterruptedException e) { - LangUtil.rethrowUnchecked(e); - } - - SingleThreadedEventLoop serverLoop = new SingleThreadedEventLoop("server"); - SingleThreadedEventLoop clientLoop = new SingleThreadedEventLoop("client"); - - AeronChannel serverChannel = - new AeronChannel( - "server", - serverPublication, - serverSubscription, - serverLoop, - clientPublication.sessionId()); - - System.out.println("created server channel"); - - CountDownLatch latch = new CountDownLatch(count); - - serverChannel - .receive() - // latch.countDown(); - // System.out.println("received -> " + f.getInt(0)); - .flatMap(serverChannel::send, 32) - .doOnError(Throwable::printStackTrace) - .subscribe(); - - AeronChannel clientChannel = - new AeronChannel( - "client", - clientPublication, - clientSubscription, - clientLoop, - serverPublication.sessionId()); - - clientChannel - .receive() - .doOnNext( - l -> { - synchronized (latch) { - latch.countDown(); - if (latch.getCount() % 10_000 == 0) { - System.out.println("mod of client got back -> " + latch.getCount()); - } - // if (latch.getCount() < 10_000) { - // System.out.println("client got back -> " + latch.getCount()); - // } - } - }) - .doOnError(Throwable::printStackTrace) - .subscribe(); - - byte[] bytes = new byte[8]; - ThreadLocalRandom.current().nextBytes(bytes); - - Flux.range(1, count) - // .doOnRequest(l -> System.out.println("requested => " + l)) - .flatMap( - i -> { - // System.out.println("Sending -> " + i); - - // UnsafeBuffer b = new UnsafeBuffer(new byte[BitUtil.SIZE_OF_INT]); - UnsafeBuffer b = new UnsafeBuffer(bytes); - b.putInt(0, i); - - return clientChannel.send(b); - }, - 8) - .doOnError(Throwable::printStackTrace) - .subscribe(); - - try { - latch.await(); - } catch (Exception t) { - LangUtil.rethrowUnchecked(t); - } - } -} diff --git a/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/internal/reactivestreams/AeronClientServerChannelTest.java b/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/internal/reactivestreams/AeronClientServerChannelTest.java deleted file mode 100644 index 1a0dd3bfc..000000000 --- a/rsocket-transport-aeron/src/test/java/io/rsocket/aeron/internal/reactivestreams/AeronClientServerChannelTest.java +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.aeron.internal.reactivestreams; - -import io.rsocket.aeron.MediaDriverHolder; -import io.rsocket.aeron.internal.AeronWrapper; -import io.rsocket.aeron.internal.DefaultAeronWrapper; -import io.rsocket.aeron.internal.EventLoop; -import io.rsocket.aeron.internal.SingleThreadedEventLoop; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ThreadLocalRandom; -import org.agrona.BitUtil; -import org.agrona.DirectBuffer; -import org.agrona.concurrent.UnsafeBuffer; -import org.junit.Assert; -import org.junit.Ignore; -import org.junit.Test; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -/** */ -@Ignore("travis does not like me") -public class AeronClientServerChannelTest { - static { - MediaDriverHolder.getInstance(); - } - - @Test(timeout = 5_000) - public void testConnect() throws Exception { - int clientId = ThreadLocalRandom.current().nextInt(0, 1_000); - int serverId = clientId + 1; - - System.out.println("test client stream id => " + clientId); - System.out.println("test server stream id => " + serverId); - - AeronWrapper aeronWrapper = new DefaultAeronWrapper(); - - // Create Client Connector - AeronSocketAddress clientManagementSocketAddress = - AeronSocketAddress.create("aeron:udp", "127.0.0.1", 39790); - EventLoop clientEventLoop = new SingleThreadedEventLoop("client"); - - AeronSocketAddress receiveAddress = AeronSocketAddress.create("aeron:udp", "127.0.0.1", 39790); - AeronSocketAddress sendAddress = AeronSocketAddress.create("aeron:udp", "127.0.0.1", 39790); - - AeronClientChannelConnector.AeronClientConfig config = - AeronClientChannelConnector.AeronClientConfig.create( - receiveAddress, sendAddress, clientId, serverId, clientEventLoop); - - AeronClientChannelConnector connector = - AeronClientChannelConnector.create( - aeronWrapper, clientManagementSocketAddress, clientEventLoop); - - // Create Server - CountDownLatch latch = new CountDownLatch(2); - - AeronChannelServer.AeronChannelConsumer consumer = - (AeronChannel aeronChannel) -> { - Assert.assertNotNull(aeronChannel); - latch.countDown(); - }; - - AeronSocketAddress serverManagementSocketAddress = - AeronSocketAddress.create("aeron:udp", "127.0.0.1", 39790); - EventLoop serverEventLoop = new SingleThreadedEventLoop("server"); - AeronChannelServer aeronChannelServer = - AeronChannelServer.create( - consumer, aeronWrapper, serverManagementSocketAddress, serverEventLoop); - - aeronChannelServer.start(); - - Publisher publisher = connector.apply(config); - Flux.from(publisher) - .doOnNext(Assert::assertNotNull) - .doOnNext(c -> latch.countDown()) - .doOnError( - t -> { - throw new RuntimeException(t); - }) - .subscribe(); - - latch.await(); - } - - @Test(timeout = 5_000) - public void testPingPong() throws Exception { - int clientId = ThreadLocalRandom.current().nextInt(2_000, 3_000); - int serverId = clientId + 1; - - System.out.println("test client stream id => " + clientId); - System.out.println("test server stream id => " + serverId); - - AeronWrapper aeronWrapper = new DefaultAeronWrapper(); - - // Create Client Connector - AeronSocketAddress clientManagementSocketAddress = - AeronSocketAddress.create("aeron:udp", "127.0.0.1", 39790); - EventLoop clientEventLoop = new SingleThreadedEventLoop("client"); - - AeronSocketAddress receiveAddress = AeronSocketAddress.create("aeron:udp", "127.0.0.1", 39790); - AeronSocketAddress sendAddress = AeronSocketAddress.create("aeron:udp", "127.0.0.1", 39790); - - AeronClientChannelConnector.AeronClientConfig config = - AeronClientChannelConnector.AeronClientConfig.create( - receiveAddress, sendAddress, clientId, serverId, clientEventLoop); - - AeronClientChannelConnector connector = - AeronClientChannelConnector.create( - aeronWrapper, clientManagementSocketAddress, clientEventLoop); - - // Create Server - - AeronChannelServer.AeronChannelConsumer consumer = - (AeronChannel aeronChannel) -> { - Assert.assertNotNull(aeronChannel); - - Flux receive = aeronChannel.receive(); - - Flux data = - receive.doOnNext(b -> System.out.println("server received => " + b.getInt(0))); - - aeronChannel.send(data).subscribe(); - }; - - AeronSocketAddress serverManagementSocketAddress = - AeronSocketAddress.create("aeron:udp", "127.0.0.1", 39790); - EventLoop serverEventLoop = new SingleThreadedEventLoop("server"); - AeronChannelServer aeronChannelServer = - AeronChannelServer.create( - consumer, aeronWrapper, serverManagementSocketAddress, serverEventLoop); - - aeronChannelServer.start(); - - Publisher publisher = connector.apply(config); - - int count = 10; - CountDownLatch latch = new CountDownLatch(count); - - Mono.from(publisher) - .flatMap( - aeronChannel -> - Mono.create( - callback -> { - Flux data = - Flux.range(1, count) - .map( - i -> { - byte[] b = new byte[BitUtil.SIZE_OF_INT]; - UnsafeBuffer buffer = new UnsafeBuffer(b); - buffer.putInt(0, i); - return buffer; - }); - - aeronChannel - .receive() - .doOnNext(b -> latch.countDown()) - .doOnNext(callback::success) - .subscribe(); - aeronChannel.send(data).subscribe(); - })) - .subscribe(); - - latch.await(); - } -} diff --git a/rsocket-transport-aeron/src/test/resources/log4j.properties b/rsocket-transport-aeron/src/test/resources/log4j.properties deleted file mode 100644 index 6477d125f..000000000 --- a/rsocket-transport-aeron/src/test/resources/log4j.properties +++ /dev/null @@ -1,33 +0,0 @@ -# -# Copyright 2016 Netflix, Inc. -#

-# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -#

-# http://www.apache.org/licenses/LICENSE-2.0 -#

-# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on -# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. -# - - -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -log4j.rootLogger=INFO, stdout - -log4j.appender.stdout=org.apache.log4j.ConsoleAppender -log4j.appender.stdout.layout=org.apache.log4j.PatternLayout -log4j.appender.stdout.layout.ConversionPattern=%d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] (%F:%L) - %m%n \ No newline at end of file diff --git a/rsocket-transport-local/build.gradle b/rsocket-transport-local/build.gradle index 859702f6f..fc32125e2 100644 --- a/rsocket-transport-local/build.gradle +++ b/rsocket-transport-local/build.gradle @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -14,8 +14,28 @@ * limitations under the License. */ +plugins { + id 'java-library' + id 'maven-publish' + id 'signing' +} + dependencies { - compile project(':rsocket-core') + api project(':rsocket-core') + + testImplementation project(':rsocket-test') + testImplementation 'io.projectreactor:reactor-test' + testImplementation 'org.assertj:assertj-core' + testImplementation 'org.junit.jupiter:junit-jupiter-api' - testCompile project(':rsocket-test') + testRuntimeOnly 'ch.qos.logback:logback-classic' + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' } + +jar { + manifest { + attributes("Automatic-Module-Name": "rsocket.transport.local") + } +} + +description = 'Local RSocket transport implementation' diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java index a0764e1a5..1b3779e85 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java @@ -1,54 +1,93 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2021 the original author or authors. * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package io.rsocket.transport.local; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; -import io.rsocket.Frame; +import io.rsocket.internal.UnboundedProcessor; import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.local.LocalServerTransport.ServerDuplexConnectionAcceptor; +import io.rsocket.transport.ServerTransport; +import java.util.Objects; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; -import reactor.core.publisher.UnicastProcessor; +import reactor.core.publisher.Sinks; + +/** + * An implementation of {@link ClientTransport} that connects to a {@link ServerTransport} in the + * same JVM. + */ +public final class LocalClientTransport implements ClientTransport { -public class LocalClientTransport implements ClientTransport { private final String name; - LocalClientTransport(String name) { + private final ByteBufAllocator allocator; + + private LocalClientTransport(String name, ByteBufAllocator allocator) { this.name = name; + this.allocator = allocator; } + /** + * Creates a new instance. + * + * @param name the name of the {@link ClientTransport} instance to connect to + * @return a new instance + * @throws NullPointerException if {@code name} is {@code null} + */ public static LocalClientTransport create(String name) { - return new LocalClientTransport(name); + Objects.requireNonNull(name, "name must not be null"); + + return create(name, ByteBufAllocator.DEFAULT); + } + + /** + * Creates a new instance. + * + * @param name the name of the {@link ClientTransport} instance to connect to + * @param allocator the allocator used by {@link ClientTransport} instance + * @return a new instance + * @throws NullPointerException if {@code name} is {@code null} + */ + public static LocalClientTransport create(String name, ByteBufAllocator allocator) { + Objects.requireNonNull(name, "name must not be null"); + Objects.requireNonNull(allocator, "allocator must not be null"); + + return new LocalClientTransport(name, allocator); } @Override public Mono connect() { return Mono.defer( () -> { - ServerDuplexConnectionAcceptor server = LocalServerTransport.findServer(name); - if (server != null) { - final UnicastProcessor in = UnicastProcessor.create(); - final UnicastProcessor out = UnicastProcessor.create(); - final MonoProcessor closeNotifier = MonoProcessor.create(); - server.accept(new LocalDuplexConnection(out, in, closeNotifier)); - DuplexConnection client = new LocalDuplexConnection(in, out, closeNotifier); - return Mono.just(client); + ServerTransport.ConnectionAcceptor server = LocalServerTransport.findServer(name); + if (server == null) { + return Mono.error(new IllegalArgumentException("Could not find server: " + name)); } - return Mono.error(new IllegalArgumentException("Could not find server: " + name)); + + Sinks.One inSink = Sinks.one(); + Sinks.One outSink = Sinks.one(); + UnboundedProcessor in = new UnboundedProcessor(inSink::tryEmitEmpty); + UnboundedProcessor out = new UnboundedProcessor(outSink::tryEmitEmpty); + + Mono onClose = inSink.asMono().and(outSink.asMono()); + + server.apply(new LocalDuplexConnection(name, allocator, out, in, onClose)).subscribe(); + + return Mono.just( + new LocalDuplexConnection(name, allocator, in, out, onClose)); }); } } diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java index dffdafe7e..c1d0fd2a3 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java @@ -1,73 +1,198 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2021 the original author or authors. * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package io.rsocket.transport.local; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.rsocket.DuplexConnection; -import io.rsocket.Frame; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.internal.UnboundedProcessor; +import java.net.SocketAddress; +import java.util.Objects; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; - -public class LocalDuplexConnection implements DuplexConnection { - private final Flux in; - private final Subscriber out; - private final MonoProcessor closeNotifier; - - public LocalDuplexConnection( - Flux in, Subscriber out, MonoProcessor closeNotifier) { - this.in = in; - this.out = out; - this.closeNotifier = closeNotifier; +import reactor.core.publisher.Operators; + +/** An implementation of {@link DuplexConnection} that connects inside the same JVM. */ +final class LocalDuplexConnection implements DuplexConnection { + + private final LocalSocketAddress address; + private final ByteBufAllocator allocator; + private final UnboundedProcessor in; + + private final Mono onClose; + + private final UnboundedProcessor out; + + /** + * Creates a new instance. + * + * @param name the name assigned to this local connection + * @param in the inbound {@link ByteBuf}s + * @param out the outbound {@link ByteBuf}s + * @param onClose the closing notifier + * @throws NullPointerException if {@code in}, {@code out}, or {@code onClose} are {@code null} + */ + LocalDuplexConnection( + String name, + ByteBufAllocator allocator, + UnboundedProcessor in, + UnboundedProcessor out, + Mono onClose) { + this.address = new LocalSocketAddress(name); + this.allocator = Objects.requireNonNull(allocator, "allocator must not be null"); + this.in = Objects.requireNonNull(in, "in must not be null"); + this.out = Objects.requireNonNull(out, "out must not be null"); + this.onClose = Objects.requireNonNull(onClose, "onClose must not be null"); } @Override - public Mono send(Publisher frames) { - return Flux.from(frames).flatMapSequential(this::sendOne).then(); + public void dispose() { + out.onComplete(); } @Override - public Mono sendOne(Frame frame) { - return Mono.fromRunnable(() -> out.onNext(frame)); + public boolean isDisposed() { + return out.isDisposed(); } @Override - public Flux receive() { - return in; + public Mono onClose() { + return onClose; } @Override - public Mono close() { - return Mono.defer( - () -> { - out.onComplete(); - closeNotifier.onComplete(); - return closeNotifier; - }); + public Flux receive() { + return in.transform( + Operators.lift( + (__, actual) -> new ByteBufReleaserOperator(actual, this))); } @Override - public Mono onClose() { - return closeNotifier; + public void sendFrame(int streamId, ByteBuf frame) { + if (streamId == 0) { + out.tryEmitPrioritized(frame); + } else { + out.tryEmitNormal(frame); + } + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, 0, e); + out.tryEmitFinal(errorFrame); + } + + @Override + public ByteBufAllocator alloc() { + return allocator; + } + + @Override + public SocketAddress remoteAddress() { + return address; } @Override - public double availability() { - return closeNotifier.isDisposed() ? 0.0 : 1.0; + public String toString() { + return "LocalDuplexConnection{" + "address=" + address + "hash=" + hashCode() + '}'; + } + + static class ByteBufReleaserOperator + implements CoreSubscriber, Subscription, Fuseable.QueueSubscription { + + final CoreSubscriber actual; + final LocalDuplexConnection parent; + + Subscription s; + + public ByteBufReleaserOperator( + CoreSubscriber actual, LocalDuplexConnection parent) { + this.actual = actual; + this.parent = parent; + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + actual.onSubscribe(this); + } + } + + @Override + public void onNext(ByteBuf buf) { + try { + actual.onNext(buf); + } finally { + buf.release(); + } + } + + @Override + public void onError(Throwable t) { + parent.out.onError(t); + actual.onError(t); + } + + @Override + public void onComplete() { + parent.out.onComplete(); + actual.onComplete(); + } + + @Override + public void request(long n) { + s.request(n); + } + + @Override + public void cancel() { + s.cancel(); + parent.out.onComplete(); + } + + @Override + public int requestFusion(int requestedMode) { + return Fuseable.NONE; + } + + @Override + public ByteBuf poll() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public int size() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public boolean isEmpty() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } } } diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java index 55976b1c7..975cb6793 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java @@ -1,39 +1,44 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2021 the original author or authors. * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ package io.rsocket.transport.local; import io.rsocket.Closeable; import io.rsocket.DuplexConnection; +import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; +import java.util.Objects; +import java.util.Set; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import java.util.function.Consumer; +import java.util.stream.Collectors; +import reactor.core.Scannable; import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; -/** Local within process transport for RSocket. */ -public class LocalServerTransport implements ServerTransport { - private static final ConcurrentMap registry = - new ConcurrentHashMap<>(); +/** + * An implementation of {@link ServerTransport} that connects to a {@link ClientTransport} in the + * same JVM. + */ +public final class LocalServerTransport implements ServerTransport { - static ServerDuplexConnectionAcceptor findServer(String name) { - return registry.get(name); - } + private static final ConcurrentMap registry = + new ConcurrentHashMap<>(); private final String name; @@ -41,27 +46,23 @@ private LocalServerTransport(String name) { this.name = name; } + /** + * Creates an instance. + * + * @param name the name of this {@link ServerTransport} that clients will connect to + * @return a new instance + * @throws NullPointerException if {@code name} is {@code null} + */ public static LocalServerTransport create(String name) { + Objects.requireNonNull(name, "name must not be null"); return new LocalServerTransport(name); } - public LocalClientTransport clientTransport() { - return LocalClientTransport.create(name); - } - - @Override - public Mono start(ConnectionAcceptor acceptor) { - return Mono.create( - sink -> { - ServerDuplexConnectionAcceptor serverDuplexConnectionAcceptor = - new ServerDuplexConnectionAcceptor(name, acceptor); - if (registry.putIfAbsent(name, serverDuplexConnectionAcceptor) != null) { - throw new IllegalStateException("name already registered: " + name); - } - sink.success(serverDuplexConnectionAcceptor); - }); - } - + /** + * Creates an instance with a random name. + * + * @return a new instance with a random name + */ public static LocalServerTransport createEphemeral() { return create(UUID.randomUUID().toString()); } @@ -69,47 +70,109 @@ public static LocalServerTransport createEphemeral() { /** * Remove an instance from the JVM registry. * - * @param name the local transport instance to free. + * @param name the local transport instance to free + * @throws NullPointerException if {@code name} is {@code null} */ public static void dispose(String name) { - registry.remove(name); + Objects.requireNonNull(name, "name must not be null"); + ServerCloseableAcceptor sca = registry.remove(name); + if (sca != null) { + sca.dispose(); + } + } + + /** + * Retrieves an instance of {@link ConnectionAcceptor} based on the name of its {@code + * LocalServerTransport}. Returns {@code null} if that server is not registered. + * + * @param name the name of the server to retrieve + * @return the server if it has been registered, {@code null} otherwise + * @throws NullPointerException if {@code name} is {@code null} + */ + static @Nullable ConnectionAcceptor findServer(String name) { + Objects.requireNonNull(name, "name must not be null"); + + return registry.get(name); } - public String getName() { + /** Return the name associated with this local server instance. */ + String getName() { return name; } - static class ServerDuplexConnectionAcceptor implements Consumer, Closeable { + /** + * Return a new {@link LocalClientTransport} connected to this {@code LocalServerTransport} + * through its {@link #getName()}. + */ + public LocalClientTransport clientTransport() { + return LocalClientTransport.create(name); + } + + @Override + public Mono start(ConnectionAcceptor acceptor) { + Objects.requireNonNull(acceptor, "acceptor must not be null"); + return Mono.create( + sink -> { + ServerCloseableAcceptor closeable = new ServerCloseableAcceptor(name, acceptor); + if (registry.putIfAbsent(name, closeable) != null) { + sink.error(new IllegalStateException("name already registered: " + name)); + } + sink.success(closeable); + }); + } + + @SuppressWarnings({"ReactorTransformationOnMonoVoid", "CallingSubscribeInNonBlockingScope"}) + static class ServerCloseableAcceptor implements ConnectionAcceptor, Closeable { + private final LocalSocketAddress address; + private final ConnectionAcceptor acceptor; - private final MonoProcessor closeNotifier = MonoProcessor.create(); - public ServerDuplexConnectionAcceptor(String name, ConnectionAcceptor acceptor) { + private final Set activeConnections = ConcurrentHashMap.newKeySet(); + + private final Sinks.Empty onClose = Sinks.unsafe().empty(); + + ServerCloseableAcceptor(String name, ConnectionAcceptor acceptor) { + Objects.requireNonNull(name, "name must not be null"); this.address = new LocalSocketAddress(name); this.acceptor = acceptor; } @Override - public void accept(DuplexConnection duplexConnection) { - acceptor.apply(duplexConnection).subscribe(); + public Mono apply(DuplexConnection duplexConnection) { + activeConnections.add(duplexConnection); + duplexConnection + .onClose() + .doFinally(__ -> activeConnections.remove(duplexConnection)) + .subscribe(); + return acceptor.apply(duplexConnection); + } + + @Override + public void dispose() { + if (!registry.remove(address.getName(), this)) { + // already disposed + return; + } + + Mono.whenDelayError( + activeConnections + .stream() + .peek(DuplexConnection::dispose) + .map(DuplexConnection::onClose) + .collect(Collectors.toList())) + .subscribe(null, onClose::tryEmitError, onClose::tryEmitEmpty); } @Override - public Mono close() { - return Mono.defer( - () -> { - if (!registry.remove(address.getName(), this)) { - throw new AssertionError(); - } - - closeNotifier.onComplete(); - return Mono.empty(); - }); + @SuppressWarnings("ConstantConditions") + public boolean isDisposed() { + return onClose.scan(Scannable.Attr.TERMINATED) || onClose.scan(Scannable.Attr.CANCELLED); } @Override public Mono onClose() { - return closeNotifier; + return onClose.asMono(); } } } diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalSocketAddress.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalSocketAddress.java index 48aece758..4d0da126a 100644 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalSocketAddress.java +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalSocketAddress.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -17,22 +17,32 @@ package io.rsocket.transport.local; import java.net.SocketAddress; +import java.util.Objects; -public class LocalSocketAddress extends SocketAddress { +/** An implementation of {@link SocketAddress} representing a local connection. */ +public final class LocalSocketAddress extends SocketAddress { + + private static final long serialVersionUID = -7513338854585475473L; - private static final long serialVersionUID = -5974652906020342524L; private final String name; + /** + * Creates a new instance. + * + * @param name the name representing the address + * @throws NullPointerException if {@code name} is {@code null} + */ public LocalSocketAddress(String name) { - this.name = name; + this.name = Objects.requireNonNull(name, "name must not be null"); } + /** Return the name for this connection. */ public String getName() { return name; } @Override public String toString() { - return "[local server] " + name; + return "[local address] " + name; } } diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalUriHandler.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalUriHandler.java deleted file mode 100644 index a24842982..000000000 --- a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalUriHandler.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.local; - -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import io.rsocket.uri.UriHandler; -import java.net.URI; -import java.util.Optional; - -public class LocalUriHandler implements UriHandler { - @Override - public Optional buildClient(URI uri) { - if ("local".equals(uri.getScheme())) { - return Optional.of(LocalClientTransport.create(uri.getSchemeSpecificPart())); - } - - return UriHandler.super.buildClient(uri); - } - - @Override - public Optional buildServer(URI uri) { - if ("local".equals(uri.getScheme())) { - return Optional.of(LocalServerTransport.create(uri.getSchemeSpecificPart())); - } - - return UriHandler.super.buildServer(uri); - } -} diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/package-info.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/package-info.java new file mode 100644 index 000000000..6a67f6af4 --- /dev/null +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** The local RSocket transport implementation. */ +@NonNullApi +package io.rsocket.transport.local; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-transport-local/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler b/rsocket-transport-local/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler deleted file mode 100644 index 6b72bbd84..000000000 --- a/rsocket-transport-local/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler +++ /dev/null @@ -1,17 +0,0 @@ -# -# Copyright 2016 Netflix, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -io.rsocket.transport.local.LocalUriHandler diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientSetupRule.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientSetupRule.java deleted file mode 100644 index f13256576..000000000 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientSetupRule.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.local; - -import io.rsocket.Closeable; -import io.rsocket.test.ClientSetupRule; -import java.util.concurrent.atomic.AtomicInteger; - -public class LocalClientSetupRule extends ClientSetupRule { - private static final AtomicInteger uniqueNameGenerator = new AtomicInteger(); - - public LocalClientSetupRule() { - super( - () -> "test" + uniqueNameGenerator.incrementAndGet(), - (address, server) -> LocalClientTransport.create(address), - LocalServerTransport::create); - } -} diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientTransportTest.java new file mode 100644 index 000000000..095de3f0e --- /dev/null +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientTransportTest.java @@ -0,0 +1,73 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import io.rsocket.Closeable; +import java.time.Duration; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import reactor.test.StepVerifier; + +final class LocalClientTransportTest { + + @DisplayName("connects to server") + @Test + void connect() { + LocalServerTransport serverTransport = LocalServerTransport.createEphemeral(); + + Closeable closeable = + serverTransport.start(duplexConnection -> duplexConnection.receive().then()).block(); + + try { + LocalClientTransport.create(serverTransport.getName()) + .connect() + .doOnNext(d -> d.receive().subscribe()) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + } finally { + closeable.dispose(); + closeable.onClose().block(Duration.ofSeconds(5)); + } + } + + @DisplayName("generates error if server not started") + @Test + void connectNoServer() { + LocalClientTransport.create("test-name") + .connect() + .as(StepVerifier::create) + .verifyErrorMessage("Could not find server: test-name"); + } + + @DisplayName("creates client") + @Test + void create() { + assertThat(LocalClientTransport.create("test-name")).isNotNull(); + } + + @DisplayName("throws NullPointerException with null name") + @Test + void createNullName() { + assertThatNullPointerException() + .isThrownBy(() -> LocalClientTransport.create(null)) + .withMessage("name must not be null"); + } +} diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalPingPong.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalPingPong.java index f593f88c1..9228e2d05 100644 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalPingPong.java +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalPingPong.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package io.rsocket.transport.local; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.test.PingClient; import io.rsocket.test.PingHandler; import java.time.Duration; @@ -26,23 +29,24 @@ public final class LocalPingPong { public static void main(String... args) { - RSocketFactory.receive() - .acceptor(new PingHandler()) - .transport(LocalServerTransport.create("test-local-server")) - .start() + RSocketServer.create(new PingHandler()) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(LocalServerTransport.create("test-local-server")) .block(); - Mono rSocketMono = - RSocketFactory.connect() - .transport(LocalClientTransport.create("test-local-server")) - .start(); + Mono client = + RSocketConnector.create() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(LocalClientTransport.create("test-local-server")); - PingClient pingClient = new PingClient(rSocketMono); + PingClient pingClient = new PingClient(client); Recorder recorder = pingClient.startTracker(Duration.ofSeconds(1)); - final int count = 1_000_000_000; + + int count = 1_000_000_000; + pingClient - .startPingPong(count, recorder) + .requestResponsePingPong(count, recorder) .doOnTerminate(() -> System.out.println("Sent " + count + " messages.")) .blockLast(); } diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableTransportTest.java new file mode 100644 index 000000000..28c1dacac --- /dev/null +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableTransportTest.java @@ -0,0 +1,53 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import io.rsocket.test.TransportTest; +import java.time.Duration; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestInfo; + +final class LocalResumableTransportTest implements TransportTest { + + private TransportPair transportPair; + + @BeforeEach + void createTestPair(TestInfo testInfo) { + transportPair = + new TransportPair<>( + () -> + "LocalResumableTransportTest-" + + testInfo.getDisplayName() + + "-" + + UUID.randomUUID(), + (address, server, allocator) -> LocalClientTransport.create(address, allocator), + (address, allocator) -> LocalServerTransport.create(address), + false, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(1); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableWithFragmentationTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableWithFragmentationTransportTest.java new file mode 100644 index 000000000..8ae16a0a5 --- /dev/null +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableWithFragmentationTransportTest.java @@ -0,0 +1,53 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import io.rsocket.test.TransportTest; +import java.time.Duration; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestInfo; + +final class LocalResumableWithFragmentationTransportTest implements TransportTest { + + private TransportPair transportPair; + + @BeforeEach + void createTestPair(TestInfo testInfo) { + transportPair = + new TransportPair<>( + () -> + "LocalResumableWithFragmentationTransportTest-" + + testInfo.getDisplayName() + + "-" + + UUID.randomUUID(), + (address, server, allocator) -> LocalClientTransport.create(address, allocator), + (address, allocator) -> LocalServerTransport.create(address), + true, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(1); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java index 746cdca88..e4edafc39 100644 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java @@ -1,21 +1,118 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.rsocket.transport.local; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +final class LocalServerTransportTest { + + @DisplayName("create throws NullPointerException with null name") + @Test + void createNullName() { + assertThatNullPointerException() + .isThrownBy(() -> LocalServerTransport.create(null)) + .withMessage("name must not be null"); + } + + @DisplayName("dispose removes name from registry") + @Test + void dispose() { + LocalServerTransport.dispose("test-name"); + } + + @DisplayName("dispose throws NullPointerException with null name") + @Test + void disposeNullName() { + assertThatNullPointerException() + .isThrownBy(() -> LocalServerTransport.dispose(null)) + .withMessage("name must not be null"); + } + + @DisplayName("creates transports with ephemeral names") + @Test + void ephemeral() { + LocalServerTransport serverTransport1 = LocalServerTransport.createEphemeral(); + LocalServerTransport serverTransport2 = LocalServerTransport.createEphemeral(); + + assertThat(serverTransport1.getName()).isNotEqualTo(serverTransport2.getName()); + } + + @DisplayName("returns the server by name") + @Test + void findServer() { + LocalServerTransport serverTransport = LocalServerTransport.createEphemeral(); + + serverTransport + .start(duplexConnection -> Mono.empty()) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); -import org.junit.Test; + assertThat(LocalServerTransport.findServer(serverTransport.getName())).isNotNull(); + } + + @DisplayName("returns null if server hasn't been started") + @Test + void findServerMissingName() { + assertThat(LocalServerTransport.findServer("test-name")).isNull(); + } + + @DisplayName("findServer throws NullPointerException with null name") + @Test + void findServerNullName() { + assertThatNullPointerException() + .isThrownBy(() -> LocalServerTransport.findServer(null)) + .withMessage("name must not be null"); + } + + @DisplayName("creates transport with name") + @Test + void named() { + LocalServerTransport serverTransport = LocalServerTransport.create("test-name"); + + assertThat(serverTransport.getName()).isEqualTo("test-name"); + } -public class LocalServerTransportTest { + @DisplayName("starts local server transport") @Test - public void testEphemeral() { - LocalServerTransport st1 = LocalServerTransport.createEphemeral(); - LocalServerTransport st2 = LocalServerTransport.createEphemeral(); - assertNotEquals(st2.getName(), st1.getName()); + void start() { + LocalServerTransport ephemeral = LocalServerTransport.createEphemeral(); + try { + ephemeral + .start(duplexConnection -> Mono.empty()) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + } finally { + LocalServerTransport.dispose(ephemeral.getName()); + } } + @DisplayName("start throws NullPointerException with null acceptor") @Test - public void testNamed() { - LocalServerTransport st = LocalServerTransport.create("LocalServerTransportTest"); - assertEquals("LocalServerTransportTest", st.getName()); + void startNullAcceptor() { + assertThatNullPointerException() + .isThrownBy(() -> LocalServerTransport.createEphemeral().start(null)) + .withMessage("acceptor must not be null"); } } diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalSocketAddressTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalSocketAddressTest.java new file mode 100644 index 000000000..8ad7b70ce --- /dev/null +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalSocketAddressTest.java @@ -0,0 +1,40 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +final class LocalSocketAddressTest { + + @DisplayName("constructor throws NullPointerException with null name") + @Test + void constructorNullName() { + assertThatNullPointerException() + .isThrownBy(() -> new LocalSocketAddress(null)) + .withMessage("name must not be null"); + } + + @DisplayName("returns the configured name") + @Test + void name() { + assertThat(new LocalSocketAddress("test-name").getName()).isEqualTo("test-name"); + } +} diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportTest.java new file mode 100644 index 000000000..87ad2105b --- /dev/null +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportTest.java @@ -0,0 +1,47 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import io.rsocket.test.TransportTest; +import java.time.Duration; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestInfo; + +final class LocalTransportTest implements TransportTest { + + private TransportPair transportPair; + + @BeforeEach + void createTestPair(TestInfo testInfo) { + transportPair = + new TransportPair<>( + () -> "LocalTransportTest-" + testInfo.getDisplayName() + "-" + UUID.randomUUID(), + (address, server, allocator) -> LocalClientTransport.create(address, allocator), + (address, allocator) -> LocalServerTransport.create(address)); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(1); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportWithFragmentationTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportWithFragmentationTest.java new file mode 100644 index 000000000..3ca5f5911 --- /dev/null +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportWithFragmentationTest.java @@ -0,0 +1,52 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import io.rsocket.test.TransportTest; +import java.time.Duration; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestInfo; + +final class LocalTransportWithFragmentationTest implements TransportTest { + + private TransportPair transportPair; + + @BeforeEach + void createTestPair(TestInfo testInfo) { + transportPair = + new TransportPair<>( + () -> + "LocalTransportWithFragmentationTest-" + + testInfo.getDisplayName() + + "-" + + UUID.randomUUID(), + (address, server, allocator) -> LocalClientTransport.create(address, allocator), + (address, allocator) -> LocalServerTransport.create(address), + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(1); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalUriTransportRegistryTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalUriTransportRegistryTest.java deleted file mode 100644 index a7cbb8f4b..000000000 --- a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalUriTransportRegistryTest.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.local; - -import static org.junit.Assert.assertTrue; - -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import io.rsocket.uri.UriTransportRegistry; -import org.junit.Test; - -public class LocalUriTransportRegistryTest { - @Test - public void testLocalClient() { - ClientTransport transport = UriTransportRegistry.clientForUri("local:test1"); - - assertTrue(transport instanceof LocalClientTransport); - } - - @Test - public void testLocalServer() { - ServerTransport transport = UriTransportRegistry.serverForUri("local:test1"); - - assertTrue(transport instanceof LocalServerTransport); - } -} diff --git a/rsocket-transport-local/src/test/resources/log4j.properties b/rsocket-transport-local/src/test/resources/log4j.properties deleted file mode 100644 index e1edb1274..000000000 --- a/rsocket-transport-local/src/test/resources/log4j.properties +++ /dev/null @@ -1,17 +0,0 @@ -# -# Copyright 2016 Netflix, Inc. -#

-# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -#

-# http://www.apache.org/licenses/LICENSE-2.0 -#

-# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on -# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. -# -log4j.rootLogger=INFO, stdout - -log4j.appender.stdout=org.apache.log4j.ConsoleAppender -log4j.appender.stdout.layout=org.apache.log4j.PatternLayout -log4j.appender.stdout.layout.ConversionPattern=%d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] (%F:%L) - %m%n \ No newline at end of file diff --git a/rsocket-transport-local/src/test/resources/logback-test.xml b/rsocket-transport-local/src/test/resources/logback-test.xml new file mode 100644 index 000000000..5c92235c2 --- /dev/null +++ b/rsocket-transport-local/src/test/resources/logback-test.xml @@ -0,0 +1,49 @@ + + + + + + + + %date{HH:mm:ss.SSS} %-10thread %-42logger %msg%n + + + + + ./test-out.log + false + + %-5relative %-5level %logger{35} - %msg%n + + + + + + + + + + + + + + + + + + + diff --git a/rsocket-transport-netty/build.gradle b/rsocket-transport-netty/build.gradle index 1def42fb7..39a5ceac5 100644 --- a/rsocket-transport-netty/build.gradle +++ b/rsocket-transport-netty/build.gradle @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -14,13 +14,47 @@ * limitations under the License. */ +plugins { + id 'java-library' + id 'maven-publish' + id 'signing' + id "com.google.osdetector" version "1.4.0" +} + +def os_suffix = "" +if (osdetector.classifier in ["linux-x86_64", "linux-aarch_64", "osx-x86_64", "osx-aarch_64", "windows-x86_64"]) { + os_suffix = "::" + osdetector.classifier +} + dependencies { - compile project(':rsocket-core') - compile "io.projectreactor.ipc:reactor-netty:0.7.0.RELEASE" - compile "io.netty:netty-handler:4.1.15.Final" - compile "io.netty:netty-handler-proxy:4.1.15.Final" - compile "io.netty:netty-codec-http:4.1.15.Final" - compile "io.netty:netty-transport-native-epoll:4.1.15.Final" - - testCompile project(':rsocket-test') + api project(':rsocket-core') + api "io.projectreactor.netty:reactor-netty-core" + api "io.projectreactor.netty:reactor-netty-http" + api 'org.slf4j:slf4j-api' + + testImplementation project(':rsocket-test') + testImplementation 'io.projectreactor:reactor-test' + testImplementation 'org.assertj:assertj-core' + testImplementation 'org.mockito:mockito-core' + testImplementation 'org.mockito:mockito-junit-jupiter' + testImplementation 'org.junit.jupiter:junit-jupiter-api' + testImplementation 'org.junit.jupiter:junit-jupiter-params' + + testRuntimeOnly 'org.bouncycastle:bcpkix-jdk15on' + testRuntimeOnly 'ch.qos.logback:logback-classic' + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' + testRuntimeOnly 'io.netty:netty-tcnative-boringssl-static' + os_suffix } + +jar { + manifest { + attributes("Automatic-Module-Name": "rsocket.transport.netty") + } +} + +test { + minHeapSize = "512m" + maxHeapSize = "4096m" +} + +description = 'Reactor Netty RSocket transport implementations (TCP, Websocket)' diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/NettyDuplexConnection.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/NettyDuplexConnection.java deleted file mode 100644 index 2afd85160..000000000 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/NettyDuplexConnection.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.transport.netty; - -import io.rsocket.DuplexConnection; -import io.rsocket.Frame; -import org.reactivestreams.Publisher; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.ipc.netty.NettyContext; -import reactor.ipc.netty.NettyInbound; -import reactor.ipc.netty.NettyOutbound; - -public class NettyDuplexConnection implements DuplexConnection { - private final NettyInbound in; - private final NettyOutbound out; - private final NettyContext context; - - public NettyDuplexConnection(NettyInbound in, NettyOutbound out, NettyContext context) { - this.in = in; - this.out = out; - this.context = context; - } - - @Override - public Mono send(Publisher frames) { - return Flux.from(frames).concatMap(this::sendOne).then(); - } - - @Override - public Mono sendOne(Frame frame) { - return out.sendObject(frame.content()).then(); - } - - @Override - public Flux receive() { - return in.receive().map(buf -> Frame.from(buf.retain())); - } - - @Override - public Mono close() { - return Mono.fromRunnable( - () -> { - context.dispose(); - context.channel().close(); - }); - } - - @Override - public Mono onClose() { - return context.onClose(); - } - - @Override - public double availability() { - return context.isDisposed() ? 0.0 : 1.0; - } -} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/RSocketLengthCodec.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/RSocketLengthCodec.java index 1a4e598fe..d7b368a3e 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/RSocketLengthCodec.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/RSocketLengthCodec.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,16 +16,31 @@ package io.rsocket.transport.netty; -import static io.rsocket.frame.FrameHeaderFlyweight.FRAME_LENGTH_MASK; -import static io.rsocket.frame.FrameHeaderFlyweight.FRAME_LENGTH_SIZE; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_SIZE; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.LengthFieldBasedFrameDecoder; -public class RSocketLengthCodec extends LengthFieldBasedFrameDecoder { +/** + * An extension to the Netty {@link LengthFieldBasedFrameDecoder} that encapsulates the + * RSocket-specific frame length header details. + */ +public final class RSocketLengthCodec extends LengthFieldBasedFrameDecoder { + + /** Creates a new instance of the decoder, specifying the RSocket frame length header size. */ public RSocketLengthCodec() { - super(FRAME_LENGTH_MASK, 0, FRAME_LENGTH_SIZE, 0, 0); + this(FRAME_LENGTH_MASK); + } + + /** + * Creates a new instance of the decoder, specifying the RSocket frame length header size. + * + * @param maxFrameLength maximum allowed frame length for incoming rsocket frames + */ + public RSocketLengthCodec(int maxFrameLength) { + super(maxFrameLength, 0, FRAME_LENGTH_SIZE, 0, 0); } /** diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java new file mode 100644 index 000000000..f5d36269c --- /dev/null +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java @@ -0,0 +1,98 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.internal.BaseDuplexConnection; +import java.net.SocketAddress; +import java.util.Objects; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.netty.Connection; + +/** An implementation of {@link DuplexConnection} that connects via TCP. */ +public final class TcpDuplexConnection extends BaseDuplexConnection { + private final String side; + private final Connection connection; + + /** + * Creates a new instance + * + * @param connection the {@link Connection} for managing the server + */ + public TcpDuplexConnection(Connection connection) { + this("unknown", connection); + } + + /** + * Creates a new instance + * + * @param connection the {@link Connection} for managing the server + */ + public TcpDuplexConnection(String side, Connection connection) { + this.connection = Objects.requireNonNull(connection, "connection must not be null"); + this.side = side; + + connection.outbound().send(sender).then().doFinally(__ -> connection.dispose()).subscribe(); + } + + @Override + public ByteBufAllocator alloc() { + return connection.channel().alloc(); + } + + @Override + public SocketAddress remoteAddress() { + return connection.channel().remoteAddress(); + } + + @Override + protected void doOnClose() { + connection.dispose(); + } + + @Override + public Mono onClose() { + return Mono.whenDelayError(super.onClose(), connection.onTerminate()); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(alloc(), 0, e); + sender.tryEmitFinal(FrameLengthCodec.encode(alloc(), errorFrame.readableBytes(), errorFrame)); + } + + @Override + public Flux receive() { + return connection.inbound().receive().map(FrameLengthCodec::frame); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + super.sendFrame(streamId, FrameLengthCodec.encode(alloc(), frame.readableBytes(), frame)); + } + + @Override + public String toString() { + return "TcpDuplexConnection{" + "side='" + side + '\'' + ", connection=" + connection + '}'; + } +} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpUriHandler.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpUriHandler.java deleted file mode 100644 index bc7e24269..000000000 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpUriHandler.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import io.rsocket.transport.netty.client.TcpClientTransport; -import io.rsocket.transport.netty.server.TcpServerTransport; -import io.rsocket.uri.UriHandler; -import java.net.URI; -import java.util.Optional; -import reactor.ipc.netty.tcp.TcpServer; - -public class TcpUriHandler implements UriHandler { - @Override - public Optional buildClient(URI uri) { - if ("tcp".equals(uri.getScheme())) { - return Optional.of(TcpClientTransport.create(uri.getHost(), uri.getPort())); - } - - return UriHandler.super.buildClient(uri); - } - - @Override - public Optional buildServer(URI uri) { - if ("tcp".equals(uri.getScheme())) { - return Optional.of(TcpServerTransport.create(TcpServer.create(uri.getHost(), uri.getPort()))); - } - - return UriHandler.super.buildServer(uri); - } -} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java index 754f3feff..8f1170c5b 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -15,81 +15,95 @@ */ package io.rsocket.transport.netty; -import static io.netty.buffer.Unpooled.wrappedBuffer; -import static io.rsocket.frame.FrameHeaderFlyweight.FRAME_LENGTH_SIZE; - import io.netty.buffer.ByteBuf; -import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.ByteBufAllocator; import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; import io.rsocket.DuplexConnection; -import io.rsocket.Frame; -import io.rsocket.frame.FrameHeaderFlyweight; -import org.reactivestreams.Publisher; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.internal.BaseDuplexConnection; +import java.net.SocketAddress; +import java.util.Objects; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import reactor.ipc.netty.NettyContext; -import reactor.ipc.netty.NettyInbound; -import reactor.ipc.netty.NettyOutbound; +import reactor.netty.Connection; /** - * Implementation of a DuplexConnection for Websocket. + * An implementation of {@link DuplexConnection} that connects via a Websocket. * - *

rsocket-java strongly assumes that each Frame is encoded with the length. This is not true for - * message oriented transports so this must be specifically dropped from Frames sent and stitched - * back on for frames received. + *

rsocket-java strongly assumes that each ByteBuf is encoded with the length. This is not true + * for message oriented transports so this must be specifically dropped from Frames sent and + * stitched back on for frames received. */ -public class WebsocketDuplexConnection implements DuplexConnection { - private final NettyInbound in; - private final NettyOutbound out; - private final NettyContext context; +public final class WebsocketDuplexConnection extends BaseDuplexConnection { + private final String side; + private final Connection connection; - public WebsocketDuplexConnection(NettyInbound in, NettyOutbound out, NettyContext context) { - this.in = in; - this.out = out; - this.context = context; + /** + * Creates a new instance + * + * @param connection the {@link Connection} to for managing the server + */ + public WebsocketDuplexConnection(Connection connection) { + this("unknown", connection); } - @Override - public Mono send(Publisher frames) { - return Flux.from(frames).concatMap(this::sendOne).then(); + /** + * Creates a new instance + * + * @param connection the {@link Connection} to for managing the server + */ + public WebsocketDuplexConnection(String side, Connection connection) { + this.connection = Objects.requireNonNull(connection, "connection must not be null"); + this.side = side; + + connection + .outbound() + .sendObject(sender.map(BinaryWebSocketFrame::new)) + .then() + .doFinally(__ -> connection.dispose()) + .subscribe(); } @Override - public Mono sendOne(Frame frame) { - return out.sendObject(new BinaryWebSocketFrame(frame.content().skipBytes(FRAME_LENGTH_SIZE))) - .then(); + public ByteBufAllocator alloc() { + return connection.channel().alloc(); } @Override - public Flux receive() { - return in.receive() - .map( - buf -> { - CompositeByteBuf composite = context.channel().alloc().compositeBuffer(); - ByteBuf length = wrappedBuffer(new byte[FRAME_LENGTH_SIZE]); - FrameHeaderFlyweight.encodeLength(length, 0, buf.readableBytes()); - composite.addComponents(true, length, buf.retain()); - return Frame.from(composite); - }); + public SocketAddress remoteAddress() { + return connection.channel().remoteAddress(); } @Override - public Mono close() { - return Mono.fromRunnable( - () -> { - if (!context.isDisposed()) { - context.channel().close(); - } - }); + protected void doOnClose() { + connection.dispose(); } @Override public Mono onClose() { - return context.onClose(); + return Mono.whenDelayError(super.onClose(), connection.onTerminate()); + } + + @Override + public Flux receive() { + return connection.inbound().receive(); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(alloc(), 0, e); + sender.tryEmitFinal(errorFrame); } @Override - public double availability() { - return context.isDisposed() ? 0.0 : 1.0; + public String toString() { + return "WebsocketDuplexConnection{" + + "side='" + + side + + '\'' + + ", connection=" + + connection + + '}'; } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketUriHandler.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketUriHandler.java deleted file mode 100644 index dcb3f295b..000000000 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketUriHandler.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import io.rsocket.transport.netty.client.WebsocketClientTransport; -import io.rsocket.transport.netty.server.WebsocketServerTransport; -import io.rsocket.uri.UriHandler; -import java.net.URI; -import java.util.Optional; - -public class WebsocketUriHandler implements UriHandler { - @Override - public Optional buildClient(URI uri) { - if ("ws".equals(uri.getScheme()) || "wss".equals(uri.getScheme())) { - return Optional.of(WebsocketClientTransport.create(uri)); - } - - return UriHandler.super.buildClient(uri); - } - - @Override - public Optional buildServer(URI uri) { - if ("ws".equals(uri.getScheme())) { - return Optional.of( - WebsocketServerTransport.create( - uri.getHost(), WebsocketClientTransport.getPort(uri, 80))); - } - - return UriHandler.super.buildServer(uri); - } -} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java index 8a17b9b09..84214b98c 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,57 +16,106 @@ package io.rsocket.transport.netty.client; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + import io.rsocket.DuplexConnection; import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.netty.NettyDuplexConnection; +import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.RSocketLengthCodec; +import io.rsocket.transport.netty.TcpDuplexConnection; import java.net.InetSocketAddress; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import java.util.Objects; import reactor.core.publisher.Mono; -import reactor.ipc.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpClient; + +/** + * An implementation of {@link ClientTransport} that connects to a {@link ServerTransport} via TCP. + */ +public final class TcpClientTransport implements ClientTransport { -public class TcpClientTransport implements ClientTransport { - private final Logger logger = LoggerFactory.getLogger(TcpClientTransport.class); private final TcpClient client; + private final int maxFrameLength; - private TcpClientTransport(TcpClient client) { + private TcpClientTransport(TcpClient client, int maxFrameLength) { this.client = client; + this.maxFrameLength = maxFrameLength; } + /** + * Creates a new instance connecting to localhost + * + * @param port the port to connect to + * @return a new instance + */ public static TcpClientTransport create(int port) { - TcpClient tcpClient = TcpClient.create(port); + TcpClient tcpClient = TcpClient.create().port(port); return create(tcpClient); } + /** + * Creates a new instance + * + * @param bindAddress the address to connect to + * @param port the port to connect to + * @return a new instance + * @throws NullPointerException if {@code bindAddress} is {@code null} + */ public static TcpClientTransport create(String bindAddress, int port) { - TcpClient tcpClient = TcpClient.create(bindAddress, port); + Objects.requireNonNull(bindAddress, "bindAddress must not be null"); + + TcpClient tcpClient = TcpClient.create().host(bindAddress).port(port); return create(tcpClient); } + /** + * Creates a new instance + * + * @param address the address to connect to + * @return a new instance + * @throws NullPointerException if {@code address} is {@code null} + */ public static TcpClientTransport create(InetSocketAddress address) { - TcpClient tcpClient = TcpClient.create(address.getHostString(), address.getPort()); + Objects.requireNonNull(address, "address must not be null"); + + TcpClient tcpClient = TcpClient.create().remoteAddress(() -> address); return create(tcpClient); } + /** + * Creates a new instance + * + * @param client the {@link TcpClient} to use + * @return a new instance + * @throws NullPointerException if {@code client} is {@code null} + */ public static TcpClientTransport create(TcpClient client) { - return new TcpClientTransport(client); + return create(client, FRAME_LENGTH_MASK); + } + + /** + * Creates a new instance + * + * @param client the {@link TcpClient} to use + * @param maxFrameLength max frame length being sent over the connection + * @return a new instance + * @throws NullPointerException if {@code client} is {@code null} + */ + public static TcpClientTransport create(TcpClient client, int maxFrameLength) { + Objects.requireNonNull(client, "client must not be null"); + + return new TcpClientTransport(client, maxFrameLength); + } + + @Override + public int maxFrameLength() { + return maxFrameLength; } @Override public Mono connect() { - return Mono.create( - sink -> - client - .newHandler( - (in, out) -> { - in.context().addHandler("client-length-codec", new RSocketLengthCodec()); - NettyDuplexConnection connection = - new NettyDuplexConnection(in, out, in.context()); - sink.success(connection); - return connection.onClose(); - }) - .doOnError(sink::error) - .subscribe()); + return client + .doOnConnected(c -> c.addHandlerLast(new RSocketLengthCodec(maxFrameLength))) + .connect() + .map(connection -> new TcpDuplexConnection("client", connection)); } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java index 55b5b6197..86be47893 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,97 +16,162 @@ package io.rsocket.transport.netty.client; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.HttpHeaders; import io.rsocket.DuplexConnection; import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.TransportHeaderAware; +import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.WebsocketDuplexConnection; import java.net.InetSocketAddress; import java.net.URI; -import java.util.Collections; -import java.util.Map; -import java.util.function.Supplier; +import java.util.Arrays; +import java.util.Objects; +import java.util.function.Consumer; import reactor.core.publisher.Mono; -import reactor.ipc.netty.http.client.HttpClient; +import reactor.netty.http.client.HttpClient; +import reactor.netty.http.client.WebsocketClientSpec; +import reactor.netty.tcp.TcpClient; + +/** + * An implementation of {@link ClientTransport} that connects to a {@link ServerTransport} over + * WebSocket. + */ +public final class WebsocketClientTransport implements ClientTransport { + + private static final String DEFAULT_PATH = "/"; -public class WebsocketClientTransport implements ClientTransport, TransportHeaderAware { private final HttpClient client; - private String path; - private Supplier> transportHeaders = Collections::emptyMap; + + private final String path; + + private HttpHeaders headers = new DefaultHttpHeaders(); + + private final WebsocketClientSpec.Builder specBuilder = + WebsocketClientSpec.builder().maxFramePayloadLength(FRAME_LENGTH_MASK); private WebsocketClientTransport(HttpClient client, String path) { + Objects.requireNonNull(client, "HttpClient must not be null"); + Objects.requireNonNull(path, "path must not be null"); this.client = client; - this.path = path; + this.path = path.startsWith("/") ? path : "/" + path; } + /** + * Creates a new instance connecting to localhost + * + * @param port the port to connect to + * @return a new instance + */ public static WebsocketClientTransport create(int port) { - HttpClient httpClient = HttpClient.create(port); - return create(httpClient, "/"); + return create(TcpClient.create().port(port)); } + /** + * Creates a new instance + * + * @param bindAddress the address to connect to + * @param port the port to connect to + * @return a new instance + * @throws NullPointerException if {@code bindAddress} is {@code null} + */ public static WebsocketClientTransport create(String bindAddress, int port) { - HttpClient httpClient = HttpClient.create(bindAddress, port); - return create(httpClient, "/"); + return create(TcpClient.create().host(bindAddress).port(port)); } + /** + * Creates a new instance + * + * @param address the address to connect to + * @return a new instance + * @throws NullPointerException if {@code address} is {@code null} + */ public static WebsocketClientTransport create(InetSocketAddress address) { - return create(address.getHostName(), address.getPort()); + Objects.requireNonNull(address, "address must not be null"); + return create(TcpClient.create().remoteAddress(() -> address)); } - public static WebsocketClientTransport create(URI uri) { - HttpClient httpClient = createClient(uri); - return create(httpClient, uri.toString()); + /** + * Creates a new instance + * + * @param client the {@link TcpClient} to use + * @return a new instance + * @throws NullPointerException if {@code client} or {@code path} is {@code null} + */ + public static WebsocketClientTransport create(TcpClient client) { + return new WebsocketClientTransport(HttpClient.from(client), DEFAULT_PATH); } - private static HttpClient createClient(URI uri) { - if (isSecureWebsocket(uri)) { - return HttpClient.create( - options -> - options - .sslSupport() - .connectAddress( - () -> InetSocketAddress.createUnresolved(uri.getHost(), getPort(uri, 443)))); - } else { - return HttpClient.create(uri.getHost(), getPort(uri, 80)); - } - } - - public static int getPort(URI uri, int defaultPort) { - return uri.getPort() == -1 ? defaultPort : uri.getPort(); + /** + * Creates a new instance + * + * @param uri the URI to connect to + * @return a new instance + * @throws NullPointerException if {@code uri} is {@code null} + */ + public static WebsocketClientTransport create(URI uri) { + Objects.requireNonNull(uri, "uri must not be null"); + boolean isSecure = uri.getScheme().equals("wss") || uri.getScheme().equals("https"); + TcpClient client = + (isSecure ? TcpClient.create().secure() : TcpClient.create()) + .host(uri.getHost()) + .port(uri.getPort() == -1 ? (isSecure ? 443 : 80) : uri.getPort()); + return new WebsocketClientTransport(HttpClient.from(client), uri.getPath()); } - public static boolean isSecureWebsocket(URI uri) { - return uri.getScheme().equals("wss") || uri.getScheme().equals("https"); + /** + * Creates a new instance + * + * @param client the {@link HttpClient} to use + * @param path the path to request + * @return a new instance + * @throws NullPointerException if {@code client} or {@code path} is {@code null} + */ + public static WebsocketClientTransport create(HttpClient client, String path) { + return new WebsocketClientTransport(client, path); } - public static boolean isPlaintextWebsocket(URI uri) { - return uri.getScheme().equals("ws") || uri.getScheme().equals("http"); + /** + * Add a header and value(s) to use for the WebSocket handshake request. + * + * @param name the header name + * @param values the header value(s) + * @return the same instance for method chaining + * @since 1.0.1 + */ + public WebsocketClientTransport header(String name, String... values) { + if (values != null) { + Arrays.stream(values).forEach(value -> headers.add(name, value)); + } + return this; } - public static WebsocketClientTransport create(HttpClient client, String path) { - return new WebsocketClientTransport(client, path); + /** + * Provide a consumer to customize properties of the {@link WebsocketClientSpec} to use for + * WebSocket upgrades. The consumer is invoked immediately. + * + * @param configurer the configurer to apply to the spec + * @return the same instance for method chaining + * @since 1.0.1 + */ + public WebsocketClientTransport webSocketSpec(Consumer configurer) { + configurer.accept(specBuilder); + return this; } @Override - public Mono connect() { - return Mono.create( - sink -> - client - .ws(path, hb -> transportHeaders.get().forEach(hb::set)) - .flatMap( - response -> - response.receiveWebsocket( - (in, out) -> { - WebsocketDuplexConnection connection = - new WebsocketDuplexConnection(in, out, in.context()); - sink.success(connection); - return connection.onClose(); - })) - .doOnError(sink::error) - .subscribe()); + public int maxFrameLength() { + return specBuilder.build().maxFramePayloadLength(); } @Override - public void setTransportHeaders(Supplier> transportHeaders) { - this.transportHeaders = transportHeaders; + public Mono connect() { + return client + .headers(headers -> headers.add(this.headers)) + .websocket(specBuilder.build()) + .uri(path) + .connect() + .map(connection -> new WebsocketDuplexConnection("client", connection)); } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/package-info.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/package-info.java new file mode 100644 index 000000000..4567f2012 --- /dev/null +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** The Netty-based RSocket client transport implementations. */ +@NonNullApi +package io.rsocket.transport.netty.client; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/package-info.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/package-info.java new file mode 100644 index 000000000..599500cff --- /dev/null +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** The Netty-based RSocket transport implementations. */ +@NonNullApi +package io.rsocket.transport.netty; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/BaseWebsocketServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/BaseWebsocketServerTransport.java new file mode 100644 index 000000000..33cff28b4 --- /dev/null +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/BaseWebsocketServerTransport.java @@ -0,0 +1,64 @@ +package io.rsocket.transport.netty.server; + +import static io.netty.channel.ChannelHandler.*; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Closeable; +import io.rsocket.transport.ServerTransport; +import java.util.function.Consumer; +import java.util.function.Function; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.netty.http.server.HttpServer; +import reactor.netty.http.server.WebsocketServerSpec; + +abstract class BaseWebsocketServerTransport< + SELF extends BaseWebsocketServerTransport, T extends Closeable> + implements ServerTransport { + private static final Logger logger = LoggerFactory.getLogger(BaseWebsocketServerTransport.class); + private static final ChannelHandler pongHandler = new PongHandler(); + + static Function serverConfigurer = + server -> server.doOnConnection(connection -> connection.addHandlerLast(pongHandler)); + + final WebsocketServerSpec.Builder specBuilder = + WebsocketServerSpec.builder().maxFramePayloadLength(FRAME_LENGTH_MASK); + + /** + * Provide a consumer to customize properties of the {@link WebsocketServerSpec} to use for + * WebSocket upgrades. The consumer is invoked immediately. + * + * @param configurer the configurer to apply to the spec + * @return the same instance for method chaining + * @since 1.0.1 + */ + @SuppressWarnings("unchecked") + public SELF webSocketSpec(Consumer configurer) { + configurer.accept(specBuilder); + return (SELF) this; + } + + @Override + public int maxFrameLength() { + return specBuilder.build().maxFramePayloadLength(); + } + + @Sharable + private static class PongHandler extends ChannelInboundHandlerAdapter { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof PongWebSocketFrame) { + logger.debug("received WebSocket Pong Frame"); + ReferenceCountUtil.safeRelease(msg); + ctx.read(); + } else { + ctx.fireChannelRead(msg); + } + } + } +} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/CloseableChannel.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/CloseableChannel.java new file mode 100644 index 000000000..7e98905ff --- /dev/null +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/CloseableChannel.java @@ -0,0 +1,87 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty.server; + +import io.rsocket.Closeable; +import java.lang.reflect.Method; +import java.net.InetSocketAddress; +import java.util.Objects; +import reactor.core.publisher.Mono; +import reactor.netty.DisposableChannel; + +/** + * An implementation of {@link Closeable} that wraps a {@link DisposableChannel}, enabling + * close-ability and exposing the {@link DisposableChannel}'s address. + */ +public final class CloseableChannel implements Closeable { + + /** For forward compatibility: remove when RSocket compiles against Reactor 1.0. */ + private static final Method channelAddressMethod; + + static { + try { + channelAddressMethod = DisposableChannel.class.getMethod("address"); + } catch (NoSuchMethodException ex) { + throw new IllegalStateException("Expected address method", ex); + } + } + + private final DisposableChannel channel; + + /** + * Creates a new instance + * + * @param channel the {@link DisposableChannel} to wrap + * @throws NullPointerException if {@code context} is {@code null} + */ + CloseableChannel(DisposableChannel channel) { + this.channel = Objects.requireNonNull(channel, "channel must not be null"); + } + + /** + * Return local server selector channel address. + * + * @return local {@link InetSocketAddress} + * @see DisposableChannel#address() + */ + public InetSocketAddress address() { + try { + return (InetSocketAddress) channel.address(); + } catch (ClassCastException | NoSuchMethodError e) { + try { + return (InetSocketAddress) channelAddressMethod.invoke(this.channel); + } catch (Exception ex) { + throw new IllegalStateException("Unable to obtain address", ex); + } + } + } + + @Override + public void dispose() { + channel.dispose(); + } + + @Override + public boolean isDisposed() { + return channel.isDisposed(); + } + + @Override + public Mono onClose() { + return channel.onDispose(); + } +} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/NettyContextCloseable.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/NettyContextCloseable.java deleted file mode 100644 index 620dd4c7a..000000000 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/NettyContextCloseable.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty.server; - -import io.rsocket.Closeable; -import java.net.InetSocketAddress; -import reactor.core.publisher.Mono; -import reactor.core.publisher.MonoProcessor; -import reactor.ipc.netty.NettyContext; - -/** - * A {@link Closeable} wrapping a {@link NettyContext}, allowing for close and aware of its address. - */ -public class NettyContextCloseable implements Closeable { - private NettyContext context; - - private MonoProcessor onClose; - - NettyContextCloseable(NettyContext context) { - this.onClose = MonoProcessor.create(); - this.context = context; - } - - @Override - public Mono close() { - return Mono.empty() - .doFinally( - s -> { - context.dispose(); - onClose.onComplete(); - }) - .then(); - } - - @Override - public Mono onClose() { - return onClose; - } - - /** - * @see NettyContext#address() - * @return socket address. - */ - public InetSocketAddress address() { - return context.address(); - } -} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java index 36943691f..32562c4a4 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,50 +16,109 @@ package io.rsocket.transport.netty.server; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; -import io.rsocket.transport.netty.NettyDuplexConnection; import io.rsocket.transport.netty.RSocketLengthCodec; +import io.rsocket.transport.netty.TcpDuplexConnection; import java.net.InetSocketAddress; +import java.util.Objects; import reactor.core.publisher.Mono; -import reactor.ipc.netty.tcp.TcpServer; +import reactor.netty.tcp.TcpServer; + +/** + * An implementation of {@link ServerTransport} that connects to a {@link ClientTransport} via TCP. + */ +public final class TcpServerTransport implements ServerTransport { -public class TcpServerTransport implements ServerTransport { - TcpServer server; + private final TcpServer server; + private final int maxFrameLength; - private TcpServerTransport(TcpServer server) { + private TcpServerTransport(TcpServer server, int maxFrameLength) { this.server = server; + this.maxFrameLength = maxFrameLength; } - public static TcpServerTransport create(InetSocketAddress address) { - TcpServer server = TcpServer.create(address.getHostName(), address.getPort()); + /** + * Creates a new instance binding to localhost + * + * @param port the port to bind to + * @return a new instance + */ + public static TcpServerTransport create(int port) { + TcpServer server = TcpServer.create().port(port); return create(server); } + /** + * Creates a new instance + * + * @param bindAddress the address to bind to + * @param port the port to bind to + * @return a new instance + * @throws NullPointerException if {@code bindAddress} is {@code null} + */ public static TcpServerTransport create(String bindAddress, int port) { - TcpServer server = TcpServer.create(bindAddress, port); + Objects.requireNonNull(bindAddress, "bindAddress must not be null"); + TcpServer server = TcpServer.create().host(bindAddress).port(port); return create(server); } - public static TcpServerTransport create(int port) { - TcpServer server = TcpServer.create(port); - return create(server); + /** + * Creates a new instance + * + * @param address the address to bind to + * @return a new instance + * @throws NullPointerException if {@code address} is {@code null} + */ + public static TcpServerTransport create(InetSocketAddress address) { + Objects.requireNonNull(address, "address must not be null"); + return create(address.getHostName(), address.getPort()); } + /** + * Creates a new instance + * + * @param server the {@link TcpServer} to use + * @return a new instance + * @throws NullPointerException if {@code server} is {@code null} + */ public static TcpServerTransport create(TcpServer server) { - return new TcpServerTransport(server); + return create(server, FRAME_LENGTH_MASK); + } + + /** + * Creates a new instance + * + * @param server the {@link TcpServer} to use + * @param maxFrameLength max frame length being sent over the connection + * @return a new instance + * @throws NullPointerException if {@code server} is {@code null} + */ + public static TcpServerTransport create(TcpServer server, int maxFrameLength) { + Objects.requireNonNull(server, "server must not be null"); + return new TcpServerTransport(server, maxFrameLength); } @Override - public Mono start(ConnectionAcceptor acceptor) { - return server - .newHandler( - (in, out) -> { - in.context().addHandler("server-length-codec", new RSocketLengthCodec()); - NettyDuplexConnection connection = new NettyDuplexConnection(in, out, in.context()); - acceptor.apply(connection).subscribe(); + public int maxFrameLength() { + return maxFrameLength; + } - return out.neverComplete(); + @Override + public Mono start(ConnectionAcceptor acceptor) { + Objects.requireNonNull(acceptor, "acceptor must not be null"); + return server + .doOnConnection( + c -> { + c.addHandlerLast(new RSocketLengthCodec(maxFrameLength)); + acceptor + .apply(new TcpDuplexConnection("server", c)) + .then(Mono.never()) + .subscribe(c.disposeSubscriber()); }) - .map(NettyContextCloseable::new); + .bind() + .map(CloseableChannel::new); } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java index 8d5db9bcf..db13720e7 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java @@ -1,46 +1,87 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.rsocket.transport.netty.server; import io.rsocket.Closeable; import io.rsocket.transport.ServerTransport; import io.rsocket.transport.netty.WebsocketDuplexConnection; -import io.rsocket.util.CloseableAdapter; +import java.util.Objects; import java.util.function.BiFunction; +import java.util.function.Consumer; import org.reactivestreams.Publisher; import reactor.core.publisher.Mono; -import reactor.ipc.netty.http.server.HttpServerRoutes; -import reactor.ipc.netty.http.websocket.WebsocketInbound; -import reactor.ipc.netty.http.websocket.WebsocketOutbound; +import reactor.netty.Connection; +import reactor.netty.http.server.HttpServer; +import reactor.netty.http.server.HttpServerRoutes; +import reactor.netty.http.websocket.WebsocketInbound; +import reactor.netty.http.websocket.WebsocketOutbound; + +/** + * An implementation of {@link ServerTransport} that connects via Websocket and listens on specified + * routes. + */ +public final class WebsocketRouteTransport + extends BaseWebsocketServerTransport { + + private final String path; -public class WebsocketRouteTransport implements ServerTransport { - private HttpServerRoutes routes; - private String path; + private final Consumer routesBuilder; - public WebsocketRouteTransport(HttpServerRoutes routes, String path) { - this.routes = routes; - this.path = path; + private final HttpServer server; + + /** + * Creates a new instance + * + * @param server the {@link HttpServer} to use + * @param routesBuilder the builder for the routes that will be listened on + * @param path the path foe each route + */ + public WebsocketRouteTransport( + HttpServer server, Consumer routesBuilder, String path) { + this.server = serverConfigurer.apply(Objects.requireNonNull(server, "server must not be null")); + this.routesBuilder = Objects.requireNonNull(routesBuilder, "routesBuilder must not be null"); + this.path = Objects.requireNonNull(path, "path must not be null"); } @Override public Mono start(ConnectionAcceptor acceptor) { - return Mono.defer( - () -> { - routes.ws(path, newHandler(acceptor)); - - return Mono.just( - new CloseableAdapter( - () -> { - // TODO close route somehow - })); - }); + Objects.requireNonNull(acceptor, "acceptor must not be null"); + return server + .route( + routes -> { + routesBuilder.accept(routes); + routes.ws(path, newHandler(acceptor), specBuilder.build()); + }) + .bind() + .map(CloseableChannel::new); } + /** + * Creates a new Websocket handler + * + * @param acceptor the {@link ConnectionAcceptor} to use with the handler + * @return a new Websocket handler + * @throws NullPointerException if {@code acceptor} is {@code null} + */ public static BiFunction> newHandler( ConnectionAcceptor acceptor) { - return (in, out) -> { - WebsocketDuplexConnection connection = new WebsocketDuplexConnection(in, out, in.context()); - acceptor.apply(connection).subscribe(); - - return out.neverComplete(); - }; + return (in, out) -> + acceptor + .apply(new WebsocketDuplexConnection("server", (Connection) in)) + .then(out.neverComplete()); } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java index c208a3ee5..4fe736fad 100644 --- a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -16,50 +16,112 @@ package io.rsocket.transport.netty.server; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.HttpHeaders; +import io.rsocket.transport.ClientTransport; import io.rsocket.transport.ServerTransport; -import io.rsocket.transport.TransportHeaderAware; -import java.util.Collections; -import java.util.Map; -import java.util.function.Supplier; +import io.rsocket.transport.netty.WebsocketDuplexConnection; +import java.net.InetSocketAddress; +import java.util.Arrays; +import java.util.Objects; import reactor.core.publisher.Mono; -import reactor.ipc.netty.http.server.HttpServer; +import reactor.netty.Connection; +import reactor.netty.http.server.HttpServer; -public class WebsocketServerTransport - implements ServerTransport, TransportHeaderAware { - HttpServer server; - private Supplier> transportHeaders = Collections::emptyMap; +/** + * An implementation of {@link ServerTransport} that connects to a {@link ClientTransport} via a + * Websocket. + */ +public final class WebsocketServerTransport + extends BaseWebsocketServerTransport { + + private final HttpServer server; + + private HttpHeaders headers = new DefaultHttpHeaders(); private WebsocketServerTransport(HttpServer server) { - this.server = server; + this.server = serverConfigurer.apply(Objects.requireNonNull(server, "server must not be null")); } - public static WebsocketServerTransport create(String bindAddress, int port) { - HttpServer httpServer = HttpServer.create(bindAddress, port); + /** + * Creates a new instance binding to localhost + * + * @param port the port to bind to + * @return a new instance + */ + public static WebsocketServerTransport create(int port) { + HttpServer httpServer = HttpServer.create().port(port); return create(httpServer); } - public static WebsocketServerTransport create(int port) { - HttpServer httpServer = HttpServer.create(port); + /** + * Creates a new instance + * + * @param bindAddress the address to bind to + * @param port the port to bind to + * @return a new instance + * @throws NullPointerException if {@code bindAddress} is {@code null} + */ + public static WebsocketServerTransport create(String bindAddress, int port) { + Objects.requireNonNull(bindAddress, "bindAddress must not be null"); + HttpServer httpServer = HttpServer.create().host(bindAddress).port(port); return create(httpServer); } - public static WebsocketServerTransport create(HttpServer server) { + /** + * Creates a new instance + * + * @param address the address to bind to + * @return a new instance + * @throws NullPointerException if {@code address} is {@code null} + */ + public static WebsocketServerTransport create(InetSocketAddress address) { + Objects.requireNonNull(address, "address must not be null"); + return create(address.getHostName(), address.getPort()); + } + + /** + * Creates a new instance + * + * @param server the {@link HttpServer} to use + * @return a new instance + * @throws NullPointerException if {@code server} is {@code null} + */ + public static WebsocketServerTransport create(final HttpServer server) { + Objects.requireNonNull(server, "server must not be null"); return new WebsocketServerTransport(server); } + /** + * Add a header and value(s) to set on the response of WebSocket handshakes. + * + * @param name the header name + * @param values the header value(s) + * @return the same instance for method chaining + * @since 1.0.1 + */ + public WebsocketServerTransport header(String name, String... values) { + if (values != null) { + Arrays.stream(values).forEach(value -> headers.add(name, value)); + } + return this; + } + @Override - public Mono start(ServerTransport.ConnectionAcceptor acceptor) { + public Mono start(ConnectionAcceptor acceptor) { + Objects.requireNonNull(acceptor, "acceptor must not be null"); return server - .newHandler( + .handle( (request, response) -> { - transportHeaders.get().forEach(response::addHeader); - return response.sendWebsocket(WebsocketRouteTransport.newHandler(acceptor)); + response.headers(headers); + return response.sendWebsocket( + (in, out) -> + acceptor + .apply(new WebsocketDuplexConnection("server", (Connection) in)) + .then(out.neverComplete()), + specBuilder.build()); }) - .map(NettyContextCloseable::new); - } - - @Override - public void setTransportHeaders(Supplier> transportHeaders) { - this.transportHeaders = transportHeaders; + .bind() + .map(CloseableChannel::new); } } diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/package-info.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/package-info.java new file mode 100644 index 000000000..031844d06 --- /dev/null +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** The Netty-based RSocket server transport implementations. */ +@NonNullApi +package io.rsocket.transport.netty.server; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-transport-netty/src/main/resources/META-INF/native-image/io.rsocket/rsocket-transport-netty/reflect-config.json b/rsocket-transport-netty/src/main/resources/META-INF/native-image/io.rsocket/rsocket-transport-netty/reflect-config.json new file mode 100644 index 000000000..3a2baa440 --- /dev/null +++ b/rsocket-transport-netty/src/main/resources/META-INF/native-image/io.rsocket/rsocket-transport-netty/reflect-config.json @@ -0,0 +1,16 @@ +[ + { + "condition": { + "typeReachable": "io.rsocket.transport.netty.RSocketLengthCodec" + }, + "name": "io.rsocket.transport.netty.RSocketLengthCodec", + "queryAllPublicMethods": true + }, + { + "condition": { + "typeReachable": "io.rsocket.transport.netty.server.BaseWebsocketServerTransport$PongHandler" + }, + "name": "io.rsocket.transport.netty.server.BaseWebsocketServerTransport$PongHandler", + "queryAllPublicMethods": true + } +] \ No newline at end of file diff --git a/rsocket-transport-netty/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler b/rsocket-transport-netty/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler deleted file mode 100644 index d762aba6a..000000000 --- a/rsocket-transport-netty/src/main/resources/META-INF/services/io.rsocket.uri.UriHandler +++ /dev/null @@ -1,18 +0,0 @@ -# -# Copyright 2016 Netflix, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -io.rsocket.transport.netty.TcpUriHandler -io.rsocket.transport.netty.WebsocketUriHandler diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/integration/FragmentTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/integration/FragmentTest.java new file mode 100644 index 000000000..23041ec65 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/integration/FragmentTest.java @@ -0,0 +1,184 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.integration; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.RSocketProxy; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Stream; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class FragmentTest { + private RSocket handler; + private CloseableChannel server; + private String message = null; + private String metaData = null; + private String responseMessage = null; + + private static Stream cases() { + return Stream.of(Arguments.of(0, 64), Arguments.of(64, 0), Arguments.of(64, 64)); + } + + public void startup(int frameSize) { + int randomPort = ThreadLocalRandom.current().nextInt(10_000, 20_000); + StringBuilder message = new StringBuilder(); + StringBuilder responseMessage = new StringBuilder(); + StringBuilder metaData = new StringBuilder(); + for (int i = 0; i < 100; i++) { + message.append("REQUEST "); + responseMessage.append("RESPONSE "); + metaData.append("METADATA "); + } + this.message = message.toString(); + this.responseMessage = responseMessage.toString(); + this.metaData = metaData.toString(); + + TcpServerTransport serverTransport = TcpServerTransport.create("localhost", randomPort); + server = + RSocketServer.create((setup, sendingSocket) -> Mono.just(new RSocketProxy(handler))) + .fragment(frameSize) + .bind(serverTransport) + .block(); + } + + private RSocket buildClient(int frameSize) { + return RSocketConnector.create() + .fragment(frameSize) + .connect(TcpClientTransport.create(server.address())) + .block(); + } + + @AfterEach + public void cleanup() { + server.dispose(); + } + + @ParameterizedTest + @MethodSource("cases") + void testFragmentNoMetaData(int clientFrameSize, int serverFrameSize) { + startup(serverFrameSize); + System.out.println( + "-------------------------------------------------testFragmentNoMetaData-------------------------------------------------"); + handler = + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + String request = payload.getDataUtf8(); + String metaData = payload.getMetadataUtf8(); + System.out.println("request message: " + request); + System.out.println("request metadata: " + metaData); + + return Flux.just(DefaultPayload.create(responseMessage)); + } + }; + + RSocket client = buildClient(clientFrameSize); + + System.out.println("original message: " + message); + System.out.println("original metadata: " + metaData); + Payload payload = client.requestStream(DefaultPayload.create(message)).blockLast(); + System.out.println("response message: " + payload.getDataUtf8()); + System.out.println("response metadata: " + payload.getMetadataUtf8()); + + assertThat(responseMessage).isEqualTo(payload.getDataUtf8()); + } + + @ParameterizedTest + @MethodSource("cases") + void testFragmentRequestMetaDataOnly(int clientFrameSize, int serverFrameSize) { + startup(serverFrameSize); + System.out.println( + "-------------------------------------------------testFragmentRequestMetaDataOnly-------------------------------------------------"); + handler = + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + String request = payload.getDataUtf8(); + String metaData = payload.getMetadataUtf8(); + System.out.println("request message: " + request); + System.out.println("request metadata: " + metaData); + + return Flux.just(DefaultPayload.create(responseMessage)); + } + }; + + RSocket client = buildClient(clientFrameSize); + + System.out.println("original message: " + message); + System.out.println("original metadata: " + metaData); + Payload payload = client.requestStream(DefaultPayload.create(message, metaData)).blockLast(); + System.out.println("response message: " + payload.getDataUtf8()); + System.out.println("response metadata: " + payload.getMetadataUtf8()); + + assertThat(responseMessage).isEqualTo(payload.getDataUtf8()); + } + + @ParameterizedTest + @MethodSource("cases") + void testFragmentBothMetaData(int clientFrameSize, int serverFrameSize) { + startup(serverFrameSize); + Payload responsePayload = DefaultPayload.create(responseMessage); + System.out.println( + "-------------------------------------------------testFragmentBothMetaData-------------------------------------------------"); + handler = + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + String request = payload.getDataUtf8(); + String metaData = payload.getMetadataUtf8(); + System.out.println("request message: " + request); + System.out.println("request metadata: " + metaData); + + return Flux.just(DefaultPayload.create(responseMessage, metaData)); + } + + @Override + public Mono requestResponse(Payload payload) { + String request = payload.getDataUtf8(); + String metaData = payload.getMetadataUtf8(); + System.out.println("request message: " + request); + System.out.println("request metadata: " + metaData); + + return Mono.just(DefaultPayload.create(responseMessage, metaData)); + } + }; + + RSocket client = buildClient(clientFrameSize); + + System.out.println("original message: " + message); + System.out.println("original metadata: " + metaData); + Payload payload = client.requestStream(DefaultPayload.create(message, metaData)).blockLast(); + System.out.println("response message: " + payload.getDataUtf8()); + System.out.println("response metadata: " + payload.getMetadataUtf8()); + + assertThat(responseMessage).isEqualTo(payload.getDataUtf8()); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/integration/KeepaliveTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/integration/KeepaliveTest.java new file mode 100644 index 000000000..f05713215 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/integration/KeepaliveTest.java @@ -0,0 +1,190 @@ +package io.rsocket.integration; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketClient; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; +import reactor.test.StepVerifier; +import reactor.util.retry.Retry; +import reactor.util.retry.RetryBackoffSpec; + +/** + * Test case that reproduces the following GitHub Issue + */ +public class KeepaliveTest { + + private static final Logger LOG = LoggerFactory.getLogger(KeepaliveTest.class); + private static final int PORT = 23200; + + private CloseableChannel server; + + @BeforeEach + void setUp() { + server = createServer().block(); + } + + @AfterEach + void tearDown() { + server.dispose(); + server.onClose().block(); + } + + @Test + void keepAliveTest() { + RSocketClient rsocketClient = createClient(); + + int expectedCount = 4; + AtomicBoolean sleepOnce = new AtomicBoolean(true); + StepVerifier.create( + Flux.range(0, expectedCount) + .delayElements(Duration.ofMillis(2000)) + .concatMap( + i -> + rsocketClient + .requestResponse(Mono.just(DefaultPayload.create(""))) + .doOnNext( + __ -> { + if (sleepOnce.getAndSet(false)) { + try { + LOG.info("Sleeping..."); + Thread.sleep(1_000); + LOG.info("Waking up."); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + }) + .log("id " + i) + .onErrorComplete())) + .expectSubscription() + .expectNextCount(expectedCount) + .verifyComplete(); + } + + @Test + void keepAliveTestLazy() { + Mono rsocketMono = createClientLazy(); + + int expectedCount = 4; + AtomicBoolean sleepOnce = new AtomicBoolean(true); + StepVerifier.create( + Flux.range(0, expectedCount) + .delayElements(Duration.ofMillis(2000)) + .concatMap( + i -> + rsocketMono.flatMap( + rsocket -> + rsocket + .requestResponse(DefaultPayload.create("")) + .doOnNext( + __ -> { + if (sleepOnce.getAndSet(false)) { + try { + LOG.info("Sleeping..."); + Thread.sleep(1_000); + LOG.info("Waking up."); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + }) + .log("id " + i) + .onErrorComplete()))) + .expectSubscription() + .expectNextCount(expectedCount) + .verifyComplete(); + } + + private static Mono createServer() { + LOG.info("Starting server at port {}", PORT); + + TcpServer tcpServer = TcpServer.create().host("localhost").port(PORT); + + return RSocketServer.create( + (setupPayload, rSocket) -> { + rSocket + .onClose() + .doFirst(() -> LOG.info("Connected on server side.")) + .doOnTerminate(() -> LOG.info("Connection closed on server side.")) + .subscribe(); + + return Mono.just(new MyServerRsocket()); + }) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(TcpServerTransport.create(tcpServer)) + .doOnNext(closeableChannel -> LOG.info("RSocket server started.")); + } + + private static RSocketClient createClient() { + LOG.info("Connecting...."); + + Function reconnectSpec = + reason -> + Retry.backoff(Long.MAX_VALUE, Duration.ofSeconds(10L)) + .doBeforeRetry(retrySignal -> LOG.info("Reconnecting. Reason: {}", reason)); + + Mono rsocketMono = + RSocketConnector.create() + .fragment(16384) + .reconnect(reconnectSpec.apply("connector-close")) + .keepAlive(Duration.ofMillis(100L), Duration.ofMillis(900L)) + .connect(TcpClientTransport.create(TcpClient.create().host("localhost").port(PORT))); + + RSocketClient client = RSocketClient.from(rsocketMono); + + client + .source() + .doOnNext(r -> LOG.info("Got RSocket")) + .flatMap(RSocket::onClose) + .doOnError(err -> LOG.error("Error during onClose.", err)) + .retryWhen(reconnectSpec.apply("client-close")) + .doFirst(() -> LOG.info("Connected on client side.")) + .doOnTerminate(() -> LOG.info("Connection closed on client side.")) + .repeat() + .subscribe(); + + return client; + } + + private static Mono createClientLazy() { + LOG.info("Connecting...."); + + Function reconnectSpec = + reason -> + Retry.backoff(Long.MAX_VALUE, Duration.ofSeconds(10L)) + .doBeforeRetry(retrySignal -> LOG.info("Reconnecting. Reason: {}", reason)); + + return RSocketConnector.create() + .fragment(16384) + .reconnect(reconnectSpec.apply("connector-close")) + .keepAlive(Duration.ofMillis(100L), Duration.ofMillis(900L)) + .connect(TcpClientTransport.create(TcpClient.create().host("localhost").port(PORT))); + } + + public static class MyServerRsocket implements RSocket { + + @Override + public Mono requestResponse(Payload payload) { + return Mono.just("Pong").map(DefaultPayload::create); + } + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/NettyUriTransportRegistryTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/NettyUriTransportRegistryTest.java deleted file mode 100644 index 209236420..000000000 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/NettyUriTransportRegistryTest.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import static org.junit.Assert.assertTrue; - -import io.rsocket.transport.ClientTransport; -import io.rsocket.transport.ServerTransport; -import io.rsocket.transport.netty.client.TcpClientTransport; -import io.rsocket.transport.netty.client.WebsocketClientTransport; -import io.rsocket.transport.netty.server.TcpServerTransport; -import io.rsocket.transport.netty.server.WebsocketServerTransport; -import io.rsocket.uri.UriTransportRegistry; -import org.junit.Test; - -public class NettyUriTransportRegistryTest { - @Test - public void testTcpClient() { - ClientTransport transport = UriTransportRegistry.clientForUri("tcp://localhost:9898"); - - assertTrue(transport instanceof TcpClientTransport); - } - - @Test - public void testTcpServer() { - ServerTransport transport = UriTransportRegistry.serverForUri("tcp://localhost:9898"); - - assertTrue(transport instanceof TcpServerTransport); - } - - @Test - public void testWsClient() { - ClientTransport transport = UriTransportRegistry.clientForUri("ws://localhost:9898"); - - assertTrue(transport instanceof WebsocketClientTransport); - } - - @Test - public void testWsServer() { - ServerTransport transport = UriTransportRegistry.serverForUri("ws://localhost:9898"); - - assertTrue(transport instanceof WebsocketServerTransport); - } -} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/RSocketFactoryNettyTransportFragmentationTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/RSocketFactoryNettyTransportFragmentationTest.java new file mode 100644 index 000000000..b9c0d4f60 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/RSocketFactoryNettyTransportFragmentationTest.java @@ -0,0 +1,80 @@ +package io.rsocket.transport.netty; + +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import java.time.Duration; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +class RSocketFactoryNettyTransportFragmentationTest { + + static Stream> arguments() { + return Stream.of(TcpServerTransport.create(0), WebsocketServerTransport.create(0)); + } + + @ParameterizedTest + @MethodSource("arguments") + void serverSucceedsWithEnabledFragmentationOnSufficientMtu( + ServerTransport serverTransport) { + Mono server = + RSocketServer.create(mockAcceptor()) + .fragment(100) + .bind(serverTransport) + .doOnNext(CloseableChannel::dispose); + StepVerifier.create(server).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); + } + + @ParameterizedTest + @MethodSource("arguments") + void serverSucceedsWithDisabledFragmentation(ServerTransport serverTransport) { + Mono server = + RSocketServer.create(mockAcceptor()) + .bind(serverTransport) + .doOnNext(CloseableChannel::dispose); + StepVerifier.create(server).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); + } + + @ParameterizedTest + @MethodSource("arguments") + void clientSucceedsWithEnabledFragmentationOnSufficientMtu( + ServerTransport serverTransport) { + CloseableChannel server = + RSocketServer.create(mockAcceptor()).fragment(100).bind(serverTransport).block(); + + Mono rSocket = + RSocketConnector.create() + .fragment(100) + .connect(TcpClientTransport.create(server.address())) + .doFinally(s -> server.dispose()); + StepVerifier.create(rSocket).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); + } + + @ParameterizedTest + @MethodSource("arguments") + void clientSucceedsWithDisabledFragmentation(ServerTransport serverTransport) { + CloseableChannel server = RSocketServer.create(mockAcceptor()).bind(serverTransport).block(); + + Mono rSocket = + RSocketConnector.connectWith(TcpClientTransport.create(server.address())) + .doFinally(s -> server.dispose()); + StepVerifier.create(rSocket).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); + } + + private SocketAcceptor mockAcceptor() { + SocketAcceptor mock = Mockito.mock(SocketAcceptor.class); + Mockito.when(mock.accept(Mockito.any(), Mockito.any())) + .thenReturn(Mono.just(Mockito.mock(RSocket.class))); + return mock; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SecureWebsocketClientSetupRule.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SecureWebsocketClientSetupRule.java deleted file mode 100644 index c4f160eb8..000000000 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SecureWebsocketClientSetupRule.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import io.netty.handler.ssl.util.InsecureTrustManagerFactory; -import io.rsocket.test.ClientSetupRule; -import io.rsocket.transport.netty.client.WebsocketClientTransport; -import io.rsocket.transport.netty.server.NettyContextCloseable; -import io.rsocket.transport.netty.server.WebsocketServerTransport; -import java.net.InetSocketAddress; -import reactor.ipc.netty.http.client.HttpClient; -import reactor.ipc.netty.http.server.HttpServer; - -public class SecureWebsocketClientSetupRule - extends ClientSetupRule { - - public SecureWebsocketClientSetupRule() { - super( - () -> new InetSocketAddress("localhost", 0), - (address, server) -> - WebsocketClientTransport.create( - HttpClient.create( - options -> - options - .connectAddress(server::address) - .sslSupport(c -> c.trustManager(InsecureTrustManagerFactory.INSTANCE))), - "https://" - + server.address().getHostName() - + ":" - + server.address().getPort() - + "/"), - address -> - WebsocketServerTransport.create( - HttpServer.create(options -> options.listenAddress(address).sslSelfSigned()))); - } -} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java new file mode 100644 index 000000000..76c352768 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java @@ -0,0 +1,133 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.transport.netty; + +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.exceptions.RejectedSetupException; +import io.rsocket.transport.ClientTransport; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import io.rsocket.util.DefaultPayload; +import java.net.InetSocketAddress; +import java.time.Duration; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Stream; +import org.junit.jupiter.params.provider.Arguments; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; + +public class SetupRejectionTest { + + /* + TODO Fix this test + @DisplayName( + "Rejecting setup by server causes requester RSocket disposal and RejectedSetupException") + @ParameterizedTest + @MethodSource(value = "transports")*/ + void rejectSetupTcp( + Function> serverTransport, + Function clientTransport) { + + String errorMessage = "error"; + RejectingAcceptor acceptor = new RejectingAcceptor(errorMessage); + Mono serverRequester = acceptor.requesterRSocket(); + + CloseableChannel channel = + RSocketServer.create(acceptor) + .bind(serverTransport.apply(new InetSocketAddress("localhost", 0))) + .block(Duration.ofSeconds(5)); + + ErrorConsumer errorConsumer = new ErrorConsumer(); + + RSocket clientRequester = + RSocketConnector.connectWith(clientTransport.apply(channel.address())) + .doOnError(errorConsumer) + .block(Duration.ofSeconds(5)); + + StepVerifier.create(errorConsumer.errors().next()) + .expectNextMatches( + err -> err instanceof RejectedSetupException && errorMessage.equals(err.getMessage())) + .expectComplete() + .verify(Duration.ofSeconds(5)); + + StepVerifier.create(clientRequester.onClose()).expectComplete().verify(Duration.ofSeconds(5)); + + StepVerifier.create(serverRequester.flatMap(socket -> socket.onClose())) + .expectComplete() + .verify(Duration.ofSeconds(5)); + + StepVerifier.create(clientRequester.requestResponse(DefaultPayload.create("test"))) + .expectErrorMatches( + err -> err instanceof RejectedSetupException && errorMessage.equals(err.getMessage())) + .verify(Duration.ofSeconds(5)); + + channel.dispose(); + } + + static Stream transports() { + Function> tcpServer = + TcpServerTransport::create; + Function> wsServer = + WebsocketServerTransport::create; + Function tcpClient = TcpClientTransport::create; + Function wsClient = WebsocketClientTransport::create; + + return Stream.of(Arguments.of(tcpServer, tcpClient), Arguments.of(wsServer, wsClient)); + } + + static class ErrorConsumer implements Consumer { + private final Sinks.Many errors = Sinks.many().multicast().onBackpressureBuffer(); + + @Override + public void accept(Throwable t) { + errors.tryEmitNext(t); + } + + Flux errors() { + return errors.asFlux(); + } + } + + private static class RejectingAcceptor implements SocketAcceptor { + private final String msg; + private final Sinks.Many requesters = Sinks.many().multicast().onBackpressureBuffer(); + + public RejectingAcceptor(String msg) { + this.msg = msg; + } + + @Override + public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) { + requesters.tryEmitNext(sendingSocket); + return Mono.error(new RuntimeException(msg)); + } + + public Mono requesterRSocket() { + return requesters.asFlux().next(); + } + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpClientServerTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpClientServerTest.java deleted file mode 100644 index 475dabf0c..000000000 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpClientServerTest.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.transport.netty; - -import io.rsocket.test.BaseClientServerTest; - -public class TcpClientServerTest extends BaseClientServerTest { - @Override - protected TcpClientSetupRule createClientServer() { - return new TcpClientSetupRule(); - } -} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpClientSetupRule.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpClientSetupRule.java deleted file mode 100644 index de3c00850..000000000 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpClientSetupRule.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import io.rsocket.test.ClientSetupRule; -import io.rsocket.transport.netty.client.TcpClientTransport; -import io.rsocket.transport.netty.server.NettyContextCloseable; -import io.rsocket.transport.netty.server.TcpServerTransport; -import java.net.InetSocketAddress; - -public class TcpClientSetupRule extends ClientSetupRule { - - public TcpClientSetupRule() { - super( - () -> InetSocketAddress.createUnresolved("localhost", 0), - (address, server) -> TcpClientTransport.create(server.address()), - TcpServerTransport::create); - } -} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpFragmentationTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpFragmentationTransportTest.java new file mode 100644 index 000000000..b17da654f --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpFragmentationTransportTest.java @@ -0,0 +1,60 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; + +final class TcpFragmentationTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + TcpClientTransport.create( + TcpClient.create() + .remoteAddress(server::address) + .option(ChannelOption.ALLOCATOR, allocator)), + (address, allocator) -> { + return TcpServerTransport.create( + TcpServer.create() + .bindAddress(() -> address) + .option(ChannelOption.ALLOCATOR, allocator)); + }, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(2); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPing.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPing.java index 8663b8352..88c64648c 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPing.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPing.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,28 +13,85 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package io.rsocket.transport.netty; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.Resume; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.test.PerfTest; import io.rsocket.test.PingClient; import io.rsocket.transport.netty.client.TcpClientTransport; import java.time.Duration; import org.HdrHistogram.Recorder; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import reactor.core.publisher.Mono; +@PerfTest public final class TcpPing { + private static final int INTERACTIONS_COUNT = 1_000_000_000; + private static final int port = Integer.valueOf(System.getProperty("RSOCKET_TEST_PORT", "7878")); + + @BeforeEach + void setUp() { + System.out.println("Starting ping-pong test (TCP transport)"); + System.out.println("port: " + port); + } + + @Test + void requestResponseTest() { + PingClient pingClient = newPingClient(); + Recorder recorder = pingClient.startTracker(Duration.ofSeconds(1)); + + pingClient + .requestResponsePingPong(INTERACTIONS_COUNT, recorder) + .doOnTerminate(() -> System.out.println("Sent " + INTERACTIONS_COUNT + " messages.")) + .blockLast(); + } + + @Test + void requestStreamTest() { + PingClient pingClient = newPingClient(); + Recorder recorder = pingClient.startTracker(Duration.ofSeconds(1)); - public static void main(String... args) { - Mono client = - RSocketFactory.connect().transport(TcpClientTransport.create(7878)).start(); + pingClient + .requestStreamPingPong(INTERACTIONS_COUNT, recorder) + .doOnTerminate(() -> System.out.println("Sent " + INTERACTIONS_COUNT + " messages.")) + .blockLast(); + } - PingClient pingClient = new PingClient(client); + @Test + void requestStreamResumableTest() { + PingClient pingClient = newResumablePingClient(); Recorder recorder = pingClient.startTracker(Duration.ofSeconds(1)); - final int count = 1_000_000_000; + pingClient - .startPingPong(count, recorder) - .doOnTerminate(() -> System.out.println("Sent " + count + " messages.")) + .requestStreamPingPong(INTERACTIONS_COUNT, recorder) + .doOnTerminate(() -> System.out.println("Sent " + INTERACTIONS_COUNT + " messages.")) .blockLast(); } + + private static PingClient newPingClient() { + return newPingClient(false); + } + + private static PingClient newResumablePingClient() { + return newPingClient(true); + } + + private static PingClient newPingClient(boolean isResumable) { + RSocketConnector connector = RSocketConnector.create(); + if (isResumable) { + connector.resume(new Resume()); + } + Mono rSocket = + connector + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .keepAlive(Duration.ofMinutes(1), Duration.ofMinutes(30)) + .connect(TcpClientTransport.create(port)); + + return new PingClient(rSocket); + } } diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java index 5afbff908..338868470 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,19 +13,32 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package io.rsocket.transport.netty; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketServer; +import io.rsocket.core.Resume; +import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.test.PingHandler; import io.rsocket.transport.netty.server.TcpServerTransport; public final class TcpPongServer { + private static final boolean isResume = + Boolean.valueOf(System.getProperty("RSOCKET_TEST_RESUME", "false")); + private static final int port = Integer.valueOf(System.getProperty("RSOCKET_TEST_PORT", "7878")); public static void main(String... args) { - RSocketFactory.receive() - .acceptor(new PingHandler()) - .transport(TcpServerTransport.create(7878)) - .start() + System.out.println("Starting TCP ping-pong server"); + System.out.println("port: " + port); + System.out.println("resume enabled: " + isResume); + + RSocketServer server = RSocketServer.create(new PingHandler()); + if (isResume) { + server.resume(new Resume()); + } + server + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(TcpServerTransport.create("localhost", port)) .block() .onClose() .block(); diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableTransportTest.java new file mode 100644 index 000000000..7be1c1c54 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableTransportTest.java @@ -0,0 +1,61 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; + +final class TcpResumableTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + TcpClientTransport.create( + TcpClient.create() + .remoteAddress(server::address) + .option(ChannelOption.ALLOCATOR, allocator)), + (address, allocator) -> { + return TcpServerTransport.create( + TcpServer.create() + .bindAddress(() -> address) + .option(ChannelOption.ALLOCATOR, allocator)); + }, + false, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(3); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableWithFragmentationTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableWithFragmentationTransportTest.java new file mode 100644 index 000000000..39b3cec67 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableWithFragmentationTransportTest.java @@ -0,0 +1,61 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; + +final class TcpResumableWithFragmentationTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + TcpClientTransport.create( + TcpClient.create() + .remoteAddress(server::address) + .option(ChannelOption.ALLOCATOR, allocator)), + (address, allocator) -> { + return TcpServerTransport.create( + TcpServer.create() + .bindAddress(() -> address) + .option(ChannelOption.ALLOCATOR, allocator)); + }, + true, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(3); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java new file mode 100644 index 000000000..ee49b83cd --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java @@ -0,0 +1,80 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.ssl.util.SelfSignedCertificate; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.net.InetSocketAddress; +import java.security.cert.CertificateException; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.core.Exceptions; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; + +public class TcpSecureTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> new InetSocketAddress("localhost", 0), + (address, server, allocator) -> + TcpClientTransport.create( + TcpClient.create() + .option(ChannelOption.ALLOCATOR, allocator) + .remoteAddress(server::address) + .secure( + ssl -> + ssl.sslContext( + SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE)))), + (address, allocator) -> { + try { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + TcpServer server = + TcpServer.create() + .option(ChannelOption.ALLOCATOR, allocator) + .bindAddress(() -> address) + .secure( + ssl -> + ssl.sslContext( + SslContextBuilder.forServer( + ssc.certificate(), ssc.privateKey()))); + return TcpServerTransport.create(server); + } catch (CertificateException e) { + throw Exceptions.propagate(e); + } + }); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(10); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpTransportTest.java new file mode 100644 index 000000000..428681f3e --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpTransportTest.java @@ -0,0 +1,59 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; + +final class TcpTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + TcpClientTransport.create( + TcpClient.create() + .remoteAddress(server::address) + .option(ChannelOption.ALLOCATOR, allocator)), + (address, allocator) -> { + return TcpServerTransport.create( + TcpServer.create() + .bindAddress(() -> address) + .option(ChannelOption.ALLOCATOR, allocator)); + }); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(2); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketClient.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketClient.java new file mode 100644 index 000000000..2deb4a4a8 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketClient.java @@ -0,0 +1,128 @@ +package io.rsocket.transport.netty; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.websocketx.*; +import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketClientCompressionHandler; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import java.io.BufferedReader; +import java.io.InputStreamReader; +import java.net.URI; + +/** + * This is an example of a WebSocket client. + * + *

In order to run this example you need a compatible WebSocket server. Therefore you can either + * start the WebSocket server from the examples or connect to an existing WebSocket server such as + * ws://echo.websocket.org. + * + *

The client will attempt to connect to the URI passed to it as the first argument. You don't + * have to specify any arguments if you want to connect to the example WebSocket server, as this is + * the default. + */ +public final class WebSocketClient { + + static final String URL = System.getProperty("url", "ws://127.0.0.1:7878/websocket"); + + public static void main(String[] args) throws Exception { + URI uri = new URI(URL); + String scheme = uri.getScheme() == null ? "ws" : uri.getScheme(); + final String host = uri.getHost() == null ? "127.0.0.1" : uri.getHost(); + final int port; + if (uri.getPort() == -1) { + if ("ws".equalsIgnoreCase(scheme)) { + port = 80; + } else if ("wss".equalsIgnoreCase(scheme)) { + port = 443; + } else { + port = -1; + } + } else { + port = uri.getPort(); + } + + if (!"ws".equalsIgnoreCase(scheme) && !"wss".equalsIgnoreCase(scheme)) { + System.err.println("Only WS(S) is supported."); + return; + } + + final boolean ssl = "wss".equalsIgnoreCase(scheme); + final SslContext sslCtx; + if (ssl) { + sslCtx = + SslContextBuilder.forClient().trustManager(InsecureTrustManagerFactory.INSTANCE).build(); + } else { + sslCtx = null; + } + + EventLoopGroup group = new NioEventLoopGroup(); + try { + // Connect with V13 (RFC 6455 aka HyBi-17). You can change it to V08 or V00. + // If you change it to V00, ping is not supported and remember to change + // HttpResponseDecoder to WebSocketHttpResponseDecoder in the pipeline. + final WebSocketClientHandler handler = + new WebSocketClientHandler( + WebSocketClientHandshakerFactory.newHandshaker( + uri, WebSocketVersion.V13, null, true, new DefaultHttpHeaders())); + + Bootstrap b = new Bootstrap(); + b.group(group) + .channel(NioSocketChannel.class) + .handler( + new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) { + ChannelPipeline p = ch.pipeline(); + if (sslCtx != null) { + p.addLast(sslCtx.newHandler(ch.alloc(), host, port)); + } + p.addLast( + new HttpClientCodec(), + new HttpObjectAggregator(8192), + WebSocketClientCompressionHandler.INSTANCE, + handler); + } + }); + + Channel ch = b.connect(uri.getHost(), port).sync().channel(); + handler.handshakeFuture().sync(); + + BufferedReader console = new BufferedReader(new InputStreamReader(System.in)); + while (true) { + String msg = console.readLine(); + if (msg == null) { + break; + } else if ("bye".equals(msg.toLowerCase())) { + ch.writeAndFlush(new CloseWebSocketFrame()); + ch.closeFuture().sync(); + break; + } else if ("ping".equals(msg.toLowerCase())) { + WebSocketFrame frame = + new PingWebSocketFrame(Unpooled.wrappedBuffer(new byte[] {8, 1, 8, 1})); + ch.writeAndFlush(frame); + } else if ("pong".equals(msg.toLowerCase())) { + WebSocketFrame frame = + new PongWebSocketFrame(Unpooled.wrappedBuffer(new byte[] {8, 1, 8, 1})); + ch.writeAndFlush(frame); + } else { + WebSocketFrame frame = new TextWebSocketFrame(msg); + ch.writeAndFlush(frame); + } + } + } finally { + group.shutdownGracefully(); + } + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketClientHandler.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketClientHandler.java new file mode 100644 index 000000000..092cad2c7 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketClientHandler.java @@ -0,0 +1,90 @@ +package io.rsocket.transport.netty; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketHandshakeException; +import io.netty.util.CharsetUtil; + +public class WebSocketClientHandler extends SimpleChannelInboundHandler { + + private final WebSocketClientHandshaker handshaker; + private ChannelPromise handshakeFuture; + + public WebSocketClientHandler(WebSocketClientHandshaker handshaker) { + this.handshaker = handshaker; + } + + public ChannelFuture handshakeFuture() { + return handshakeFuture; + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + handshakeFuture = ctx.newPromise(); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) { + handshaker.handshake(ctx.channel()); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + System.out.println("WebSocket Client disconnected!"); + } + + @Override + public void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { + Channel ch = ctx.channel(); + if (!handshaker.isHandshakeComplete()) { + try { + handshaker.finishHandshake(ch, (FullHttpResponse) msg); + System.out.println("WebSocket Client connected!"); + handshakeFuture.setSuccess(); + } catch (WebSocketHandshakeException e) { + System.out.println("WebSocket Client failed to connect"); + handshakeFuture.setFailure(e); + } + return; + } + + if (msg instanceof FullHttpResponse) { + FullHttpResponse response = (FullHttpResponse) msg; + throw new IllegalStateException( + "Unexpected FullHttpResponse (getStatus=" + + response.status() + + ", content=" + + response.content().toString(CharsetUtil.UTF_8) + + ')'); + } + + WebSocketFrame frame = (WebSocketFrame) msg; + if (frame instanceof TextWebSocketFrame) { + TextWebSocketFrame textFrame = (TextWebSocketFrame) frame; + System.out.println("WebSocket Client received message: " + textFrame.text()); + } else if (frame instanceof PongWebSocketFrame) { + System.out.println("WebSocket Client received pong"); + } else if (frame instanceof CloseWebSocketFrame) { + System.out.println("WebSocket Client received closing"); + ch.close(); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + cause.printStackTrace(); + if (!handshakeFuture.isDone()) { + handshakeFuture.setFailure(cause); + } + ctx.close(); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketTransportIntegrationTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketTransportIntegrationTest.java new file mode 100644 index 000000000..c418dea0f --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketTransportIntegrationTest.java @@ -0,0 +1,49 @@ +package io.rsocket.transport.netty; + +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.WebsocketRouteTransport; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; +import java.net.URI; +import java.time.Duration; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; +import reactor.test.StepVerifier; + +public class WebSocketTransportIntegrationTest { + + @Test + public void sendStreamOfDataWithExternalHttpServerTest() { + ServerTransport.ConnectionAcceptor acceptor = + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> + Flux.range(0, 10).map(i -> DefaultPayload.create(String.valueOf(i))))) + .asConnectionAcceptor(); + + DisposableServer server = + HttpServer.create() + .host("localhost") + .route(router -> router.ws("/test", WebsocketRouteTransport.newHandler(acceptor))) + .bindNow(); + + RSocket rsocket = + RSocketConnector.connectWith( + WebsocketClientTransport.create( + URI.create("ws://" + server.host() + ":" + server.port() + "/test"))) + .block(); + + StepVerifier.create(rsocket.requestStream(EmptyPayload.INSTANCE)) + .expectSubscription() + .expectNextCount(10) + .expectComplete() + .verify(Duration.ofMillis(1000)); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketClientServerTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketClientServerTest.java deleted file mode 100644 index 91542049d..000000000 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketClientServerTest.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.rsocket.transport.netty; - -import io.rsocket.test.BaseClientServerTest; - -public class WebsocketClientServerTest extends BaseClientServerTest { - @Override - protected WebsocketClientSetupRule createClientServer() { - return new WebsocketClientSetupRule(); - } -} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketClientSetupRule.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketClientSetupRule.java deleted file mode 100644 index a8c90872e..000000000 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketClientSetupRule.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright 2016 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.transport.netty; - -import io.rsocket.test.ClientSetupRule; -import io.rsocket.transport.netty.client.WebsocketClientTransport; -import io.rsocket.transport.netty.server.NettyContextCloseable; -import io.rsocket.transport.netty.server.WebsocketServerTransport; -import java.net.InetSocketAddress; - -public class WebsocketClientSetupRule - extends ClientSetupRule { - - public WebsocketClientSetupRule() { - super( - () -> InetSocketAddress.createUnresolved("localhost", 0), - (address, server) -> WebsocketClientTransport.create(server.address()), - address -> WebsocketServerTransport.create(address.getHostName(), address.getPort())); - } -} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPing.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPing.java index ab4118cbb..a784a43c0 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPing.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPing.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package io.rsocket.transport.netty; import io.rsocket.RSocket; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketConnector; +import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.test.PingClient; import io.rsocket.transport.netty.client.WebsocketClientTransport; import java.time.Duration; @@ -27,13 +29,18 @@ public final class WebsocketPing { public static void main(String... args) { Mono client = - RSocketFactory.connect().transport(WebsocketClientTransport.create(7878)).start(); + RSocketConnector.create() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(WebsocketClientTransport.create(7878)); PingClient pingClient = new PingClient(client); + Recorder recorder = pingClient.startTracker(Duration.ofSeconds(1)); - final int count = 1_000_000_000; + + int count = 1_000_000_000; + pingClient - .startPingPong(count, recorder) + .requestResponsePingPong(count, recorder) .doOnTerminate(() -> System.out.println("Sent " + count + " messages.")) .blockLast(); } diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java new file mode 100644 index 000000000..ff0fa75b4 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java @@ -0,0 +1,168 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.transport.netty; + +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.websocketx.PingWebSocketFrame; +import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Closeable; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.WebsocketRouteTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import io.rsocket.util.DefaultPayload; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.stream.Stream; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.netty.http.client.HttpClient; +import reactor.netty.http.server.HttpServer; +import reactor.test.StepVerifier; + +public class WebsocketPingPongIntegrationTest { + private static final String host = "localhost"; + private static final int port = 8088; + + private Closeable server; + + @AfterEach + void tearDown() { + server.dispose(); + } + + @ParameterizedTest + @MethodSource("provideServerTransport") + void webSocketPingPong(ServerTransport serverTransport) { + server = + RSocketServer.create(SocketAcceptor.forRequestResponse(Mono::just)) + .bind(serverTransport) + .block(); + + String expectedData = "data"; + String expectedPing = "ping"; + + PingSender pingSender = new PingSender(); + + HttpClient httpClient = + HttpClient.create() + .tcpConfiguration( + tcpClient -> + tcpClient + .doOnConnected(b -> b.addHandlerLast(pingSender)) + .host(host) + .port(port)); + + RSocket rSocket = + RSocketConnector.connectWith(WebsocketClientTransport.create(httpClient, "/")).block(); + + rSocket + .requestResponse(DefaultPayload.create(expectedData)) + .delaySubscription(pingSender.sendPing(expectedPing)) + .as(StepVerifier::create) + .expectNextMatches(p -> expectedData.equals(p.getDataUtf8())) + .expectComplete() + .verify(Duration.ofSeconds(5)); + + pingSender + .receivePong() + .as(StepVerifier::create) + .expectNextMatches(expectedPing::equals) + .expectComplete() + .verify(Duration.ofSeconds(5)); + + rSocket + .requestResponse(DefaultPayload.create(expectedData)) + .delaySubscription(pingSender.sendPong()) + .as(StepVerifier::create) + .expectNextMatches(p -> expectedData.equals(p.getDataUtf8())) + .expectComplete() + .verify(Duration.ofSeconds(5)); + } + + private static Stream provideServerTransport() { + return Stream.of( + Arguments.of(WebsocketServerTransport.create(host, port)), + Arguments.of( + new WebsocketRouteTransport( + HttpServer.create().host(host).port(port), routes -> {}, "/"))); + } + + private static class PingSender extends ChannelInboundHandlerAdapter { + private final Sinks.One channel = Sinks.one(); + private final Sinks.One pong = Sinks.one(); + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof PongWebSocketFrame) { + pong.tryEmitValue(((PongWebSocketFrame) msg).content().toString(StandardCharsets.UTF_8)); + ReferenceCountUtil.safeRelease(msg); + ctx.read(); + } else { + super.channelRead(ctx, msg); + } + } + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + Channel ch = ctx.channel(); + if (!(channel.scan(Scannable.Attr.TERMINATED)) && ch.isWritable()) { + channel.tryEmitValue(ctx.channel()); + } + super.channelWritabilityChanged(ctx); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + Channel ch = ctx.channel(); + if (ch.isWritable()) { + channel.tryEmitValue(ch); + } + super.handlerAdded(ctx); + } + + public Mono sendPing(String data) { + return send( + new PingWebSocketFrame(Unpooled.wrappedBuffer(data.getBytes(StandardCharsets.UTF_8)))); + } + + public Mono sendPong() { + return send(new PongWebSocketFrame()); + } + + public Mono receivePong() { + return pong.asMono(); + } + + private Mono send(WebSocketFrame webSocketFrame) { + return channel.asMono().doOnNext(ch -> ch.writeAndFlush(webSocketFrame)).then(); + } + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPongServer.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPongServer.java index 7617efe60..84dc816be 100644 --- a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPongServer.java +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPongServer.java @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,19 +13,20 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + package io.rsocket.transport.netty; -import io.rsocket.RSocketFactory; +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.decoder.PayloadDecoder; import io.rsocket.test.PingHandler; import io.rsocket.transport.netty.server.WebsocketServerTransport; public final class WebsocketPongServer { public static void main(String... args) { - RSocketFactory.receive() - .acceptor(new PingHandler()) - .transport(WebsocketServerTransport.create(7878)) - .start() + RSocketServer.create(new PingHandler()) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(WebsocketServerTransport.create(7878)) .block() .onClose() .block(); diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableTransportTest.java new file mode 100644 index 000000000..043f6bc64 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableTransportTest.java @@ -0,0 +1,64 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.http.client.HttpClient; +import reactor.netty.http.server.HttpServer; + +final class WebsocketResumableTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + WebsocketClientTransport.create( + HttpClient.create() + .host(server.address().getHostName()) + .port(server.address().getPort()) + .option(ChannelOption.ALLOCATOR, allocator), + ""), + (address, allocator) -> { + return WebsocketServerTransport.create( + HttpServer.create() + .host(address.getHostName()) + .port(address.getPort()) + .option(ChannelOption.ALLOCATOR, allocator)); + }, + false, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(3); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableWithFragmentationTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableWithFragmentationTransportTest.java new file mode 100644 index 000000000..b1ca65fcc --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableWithFragmentationTransportTest.java @@ -0,0 +1,64 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.http.client.HttpClient; +import reactor.netty.http.server.HttpServer; + +final class WebsocketResumableWithFragmentationTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + WebsocketClientTransport.create( + HttpClient.create() + .host(server.address().getHostName()) + .port(server.address().getPort()) + .option(ChannelOption.ALLOCATOR, allocator), + ""), + (address, allocator) -> { + return WebsocketServerTransport.create( + HttpServer.create() + .host(address.getHostName()) + .port(address.getPort()) + .option(ChannelOption.ALLOCATOR, allocator)); + }, + true, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(3); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java new file mode 100644 index 000000000..81f7ffb95 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java @@ -0,0 +1,83 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.ssl.util.SelfSignedCertificate; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import java.net.InetSocketAddress; +import java.security.cert.CertificateException; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.core.Exceptions; +import reactor.netty.http.client.HttpClient; +import reactor.netty.http.server.HttpServer; + +final class WebsocketSecureTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> new InetSocketAddress("localhost", 0), + (address, server, allocator) -> + WebsocketClientTransport.create( + HttpClient.create() + .option(ChannelOption.ALLOCATOR, allocator) + .remoteAddress(server::address) + .secure( + ssl -> + ssl.sslContext( + SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE))), + String.format( + "https://%s:%d/", + server.address().getHostName(), server.address().getPort())), + (address, allocator) -> { + try { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + HttpServer server = + HttpServer.create() + .option(ChannelOption.ALLOCATOR, allocator) + .bindAddress(() -> address) + .secure( + ssl -> + ssl.sslContext( + SslContextBuilder.forServer( + ssc.certificate(), ssc.privateKey()))); + return WebsocketServerTransport.create(server); + } catch (CertificateException e) { + throw Exceptions.propagate(e); + } + }); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(5); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketTransportTest.java new file mode 100644 index 000000000..cdd507456 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketTransportTest.java @@ -0,0 +1,62 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.http.client.HttpClient; +import reactor.netty.http.server.HttpServer; + +final class WebsocketTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + WebsocketClientTransport.create( + HttpClient.create() + .host(server.address().getHostName()) + .port(server.address().getPort()) + .option(ChannelOption.ALLOCATOR, allocator), + ""), + (address, allocator) -> { + return WebsocketServerTransport.create( + HttpServer.create() + .host(address.getHostName()) + .port(address.getPort()) + .option(ChannelOption.ALLOCATOR, allocator)); + }); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(3); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/TcpClientTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/TcpClientTransportTest.java new file mode 100644 index 000000000..ac4c6044b --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/TcpClientTransportTest.java @@ -0,0 +1,103 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty.client; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.net.InetSocketAddress; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.netty.tcp.TcpClient; +import reactor.test.StepVerifier; + +final class TcpClientTransportTest { + + @DisplayName("connects to server") + @Test + void connect() { + InetSocketAddress address = InetSocketAddress.createUnresolved("localhost", 0); + + TcpServerTransport serverTransport = TcpServerTransport.create(address); + + serverTransport + .start(duplexConnection -> Mono.empty()) + .flatMap(context -> TcpClientTransport.create(context.address()).connect()) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + } + + @DisplayName("create generates error if server not started") + @Test + void connectNoServer() { + TcpClientTransport.create(8000).connect().as(StepVerifier::create).verifyError(); + } + + @DisplayName("creates client with BindAddress") + @Test + void createBindAddress() { + assertThat(TcpClientTransport.create("test-bind-address", 8000)).isNotNull(); + } + + @DisplayName("creates client with InetSocketAddress") + @Test + void createInetSocketAddress() { + assertThat( + TcpClientTransport.create( + InetSocketAddress.createUnresolved("test-bind-address", 8000))) + .isNotNull(); + } + + @DisplayName("create throws NullPointerException with null bindAddress") + @Test + void createNullBindAddress() { + assertThatNullPointerException() + .isThrownBy(() -> TcpClientTransport.create((String) null, 8000)) + .withMessage("bindAddress must not be null"); + } + + @DisplayName("create throws NullPointerException with null address") + @Test + void createNullInetSocketAddress() { + assertThatNullPointerException() + .isThrownBy(() -> TcpClientTransport.create((InetSocketAddress) null)) + .withMessage("address must not be null"); + } + + @DisplayName("create throws NullPointerException with null client") + @Test + void createNullTcpClient() { + assertThatNullPointerException() + .isThrownBy(() -> TcpClientTransport.create((TcpClient) null)) + .withMessage("client must not be null"); + } + + @DisplayName("creates client with port") + @Test + void createPort() { + assertThat(TcpClientTransport.create(8000)).isNotNull(); + } + + @DisplayName("creates client with TcpClient") + @Test + void createTcpClient() { + assertThat(TcpClientTransport.create(TcpClient.create())).isNotNull(); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java new file mode 100644 index 000000000..2a3670251 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java @@ -0,0 +1,152 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty.client; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import java.net.InetSocketAddress; +import java.net.URI; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Mono; +import reactor.netty.http.client.HttpClient; +import reactor.test.StepVerifier; + +@ExtendWith(MockitoExtension.class) +final class WebsocketClientTransportTest { + + @DisplayName("connects to server") + @Test + void connect() { + InetSocketAddress address = InetSocketAddress.createUnresolved("localhost", 0); + + WebsocketServerTransport serverTransport = WebsocketServerTransport.create(address); + + serverTransport + .start(duplexConnection -> Mono.empty()) + .flatMap(context -> WebsocketClientTransport.create(context.address()).connect()) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + } + + @DisplayName("create generates error if server not started") + @Test + void connectNoServer() { + WebsocketClientTransport.create(8000).connect().as(StepVerifier::create).verifyError(); + } + + @DisplayName("creates client with BindAddress") + @Test + void createBindAddress() { + assertThat(WebsocketClientTransport.create("test-bind-address", 8000)) + .isNotNull() + .hasFieldOrPropertyWithValue("path", "/"); + } + + @DisplayName("creates client with HttpClient") + @Test + void createHttpClient() { + assertThat(WebsocketClientTransport.create(HttpClient.create(), "/")) + .isNotNull() + .hasFieldOrPropertyWithValue("path", "/"); + } + + @DisplayName("creates client with HttpClient and path without root") + @Test + void createHttpClientWithPathWithoutRoot() { + assertThat(WebsocketClientTransport.create(HttpClient.create(), "test")) + .isNotNull() + .hasFieldOrPropertyWithValue("path", "/test"); + } + + @DisplayName("creates client with InetSocketAddress") + @Test + void createInetSocketAddress() { + assertThat( + WebsocketClientTransport.create( + InetSocketAddress.createUnresolved("test-bind-address", 8000))) + .isNotNull() + .hasFieldOrPropertyWithValue("path", "/"); + } + + @DisplayName("create throws NullPointerException with null bindAddress") + @Test + void createNullBindAddress() { + assertThatNullPointerException() + .isThrownBy(() -> WebsocketClientTransport.create(null, 8000)) + .withMessage("host"); + } + + @DisplayName("create throws NullPointerException with null client") + @Test + void createNullHttpClient() { + assertThatNullPointerException() + .isThrownBy(() -> WebsocketClientTransport.create(null, "/test-path")) + .withMessage("HttpClient must not be null"); + } + + @DisplayName("create throws NullPointerException with null address") + @Test + void createNullInetSocketAddress() { + assertThatNullPointerException() + .isThrownBy(() -> WebsocketClientTransport.create((InetSocketAddress) null)) + .withMessage("address must not be null"); + } + + @DisplayName("create throws NullPointerException with null path") + @Test + void createNullPath() { + assertThatNullPointerException() + .isThrownBy(() -> WebsocketClientTransport.create(HttpClient.create(), null)) + .withMessage("path must not be null"); + } + + @DisplayName("create throws NullPointerException with null URI") + @Test + void createNullUri() { + assertThatNullPointerException() + .isThrownBy(() -> WebsocketClientTransport.create((URI) null)) + .withMessage("uri must not be null"); + } + + @DisplayName("creates client with port") + @Test + void createPort() { + assertThat(WebsocketClientTransport.create(8000)).isNotNull(); + } + + @DisplayName("creates client with URI") + @Test + void createUri() { + assertThat(WebsocketClientTransport.create(URI.create("ws://test-host"))) + .isNotNull() + .hasFieldOrPropertyWithValue("path", "/"); + } + + @DisplayName("creates client with URI path") + @Test + void createUriPath() { + assertThat(WebsocketClientTransport.create(URI.create("ws://test-host/test"))) + .isNotNull() + .hasFieldOrPropertyWithValue("path", "/test"); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/CloseableChannelTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/CloseableChannelTest.java new file mode 100644 index 000000000..bd53a9b3f --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/CloseableChannelTest.java @@ -0,0 +1,73 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty.server; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.netty.DisposableChannel; +import reactor.netty.tcp.TcpServer; +import reactor.test.StepVerifier; + +final class CloseableChannelTest { + + private final Mono channel = + TcpServer.create().handle((in, out) -> Mono.empty()).bind(); + + @DisplayName("returns the address of the context") + @Test + void address() { + channel + .map(CloseableChannel::new) + .map(CloseableChannel::address) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + } + + @DisplayName("creates instance") + @Test + void constructor() { + channel.map(CloseableChannel::new).as(StepVerifier::create).expectNextCount(1).verifyComplete(); + } + + @DisplayName("constructor throws NullPointerException with null context") + @Test + void constructorNullContext() { + assertThatNullPointerException() + .isThrownBy(() -> new CloseableChannel(null)) + .withMessage("channel must not be null"); + } + + @DisplayName("disposes context") + @Test + void dispose() { + channel + .map(CloseableChannel::new) + .delayUntil( + closeable -> { + closeable.dispose(); + return closeable.onClose().log(); + }) + .as(StepVerifier::create) + .assertNext(closeable -> assertThat(closeable.isDisposed()).isTrue()) + .verifyComplete(); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/TcpServerTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/TcpServerTransportTest.java new file mode 100644 index 000000000..0e14d8f1d --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/TcpServerTransportTest.java @@ -0,0 +1,103 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty.server; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import java.net.InetSocketAddress; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.netty.tcp.TcpServer; +import reactor.test.StepVerifier; + +final class TcpServerTransportTest { + + @DisplayName("creates server with BindAddress") + @Test + void createBindAddress() { + assertThat(TcpServerTransport.create("test-bind-address", 8000)).isNotNull(); + } + + @DisplayName("creates server with InetSocketAddress") + @Test + void createInetSocketAddress() { + assertThat( + TcpServerTransport.create( + InetSocketAddress.createUnresolved("test-bind-address", 8000))) + .isNotNull(); + } + + @DisplayName("create throws NullPointerException with null bindAddress") + @Test + void createNullBindAddress() { + assertThatNullPointerException() + .isThrownBy(() -> TcpServerTransport.create((String) null, 8000)) + .withMessage("bindAddress must not be null"); + } + + @DisplayName("create throws NullPointerException with null address") + @Test + void createNullInetSocketAddress() { + assertThatNullPointerException() + .isThrownBy(() -> TcpServerTransport.create((InetSocketAddress) null)) + .withMessage("address must not be null"); + } + + @DisplayName("create throws NullPointerException with null server") + @Test + void createNullTcpClient() { + assertThatNullPointerException() + .isThrownBy(() -> TcpServerTransport.create((TcpServer) null)) + .withMessage("server must not be null"); + } + + @DisplayName("creates server with port") + @Test + void createPort() { + assertThat(TcpServerTransport.create("localhost", 8000)).isNotNull(); + } + + @DisplayName("creates client with TcpServer") + @Test + void createTcpClient() { + assertThat(TcpServerTransport.create(TcpServer.create())).isNotNull(); + } + + @DisplayName("starts server") + @Test + void start() { + InetSocketAddress address = InetSocketAddress.createUnresolved("localhost", 0); + + TcpServerTransport serverTransport = TcpServerTransport.create(address); + + serverTransport + .start(duplexConnection -> Mono.empty()) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + } + + @DisplayName("start throws NullPointerException with null acceptor") + @Test + void startNullAcceptor() { + assertThatNullPointerException() + .isThrownBy(() -> TcpServerTransport.create("localhost", 8000).start(null)) + .withMessage("acceptor must not be null"); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketRouteTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketRouteTransportTest.java new file mode 100644 index 000000000..2670b4a4b --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketRouteTransportTest.java @@ -0,0 +1,82 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty.server; + +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.netty.http.server.HttpServer; +import reactor.test.StepVerifier; + +final class WebsocketRouteTransportTest { + + @DisplayName("creates server") + @Test + void constructor() { + new WebsocketRouteTransport(HttpServer.create(), routes -> {}, "/test-path"); + } + + @DisplayName("constructor throw NullPointer with null path") + @Test + void constructorNullPath() { + assertThatNullPointerException() + .isThrownBy(() -> new WebsocketRouteTransport(HttpServer.create(), routes -> {}, null)) + .withMessage("path must not be null"); + } + + @DisplayName("constructor throw NullPointer with null routesBuilder") + @Test + void constructorNullRoutesBuilder() { + assertThatNullPointerException() + .isThrownBy(() -> new WebsocketRouteTransport(HttpServer.create(), null, "/test-path")) + .withMessage("routesBuilder must not be null"); + } + + @DisplayName("constructor throw NullPointer with null server") + @Test + void constructorNullServer() { + assertThatNullPointerException() + .isThrownBy(() -> new WebsocketRouteTransport(null, routes -> {}, "/test-path")) + .withMessage("server must not be null"); + } + + @DisplayName("starts server") + @Test + void start() { + WebsocketRouteTransport serverTransport = + new WebsocketRouteTransport(HttpServer.create(), routes -> {}, "/test-path"); + + serverTransport + .start(duplexConnection -> Mono.empty()) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + } + + @DisplayName("start throw NullPointerException with null acceptor") + @Test + void startNullAcceptor() { + assertThatNullPointerException() + .isThrownBy( + () -> + new WebsocketRouteTransport(HttpServer.create(), routes -> {}, "/test-path") + .start(null)) + .withMessage("acceptor must not be null"); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java new file mode 100644 index 000000000..540076704 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java @@ -0,0 +1,137 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty.server; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; +import static org.mockito.ArgumentMatchers.any; + +import java.net.InetSocketAddress; +import java.util.function.BiFunction; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import reactor.core.publisher.Mono; +import reactor.netty.http.server.HttpServer; +import reactor.netty.http.server.HttpServerRequest; +import reactor.netty.http.server.HttpServerResponse; +import reactor.netty.http.server.WebsocketServerSpec; +import reactor.test.StepVerifier; + +final class WebsocketServerTransportTest { + + @Test + public void testThatSetupWithUnSpecifiedFrameSizeShouldSetMaxFrameSize() { + ArgumentCaptor httpHandlerCaptor = ArgumentCaptor.forClass(BiFunction.class); + HttpServer server = Mockito.spy(HttpServer.create()); + Mockito.doAnswer(a -> server).when(server).handle(httpHandlerCaptor.capture()); + Mockito.doAnswer(a -> server).when(server).doOnConnection(any()); + Mockito.doAnswer(a -> Mono.empty()).when(server).bind(); + + WebsocketServerTransport serverTransport = WebsocketServerTransport.create(server); + serverTransport.start(c -> Mono.empty()).subscribe(); + + HttpServerRequest httpServerRequest = Mockito.mock(HttpServerRequest.class); + HttpServerResponse httpServerResponse = Mockito.mock(HttpServerResponse.class); + + httpHandlerCaptor.getValue().apply(httpServerRequest, httpServerResponse); + + ArgumentCaptor handlerCaptor = ArgumentCaptor.forClass(BiFunction.class); + ArgumentCaptor specCaptor = + ArgumentCaptor.forClass(WebsocketServerSpec.class); + + Mockito.verify(httpServerResponse).sendWebsocket(handlerCaptor.capture(), specCaptor.capture()); + + WebsocketServerSpec spec = specCaptor.getValue(); + assertThat(spec.maxFramePayloadLength()).isEqualTo(FRAME_LENGTH_MASK); + } + + @DisplayName("creates server with BindAddress") + @Test + void createBindAddress() { + assertThat(WebsocketServerTransport.create("test-bind-address", 8000)).isNotNull(); + } + + @DisplayName("creates server with HttpClient") + @Test + void createHttpClient() { + assertThat(WebsocketServerTransport.create(HttpServer.create())).isNotNull(); + } + + @DisplayName("creates server with InetSocketAddress") + @Test + void createInetSocketAddress() { + assertThat( + WebsocketServerTransport.create( + InetSocketAddress.createUnresolved("test-bind-address", 8000))) + .isNotNull(); + } + + @DisplayName("create throws NullPointerException with null bindAddress") + @Test + void createNullBindAddress() { + assertThatNullPointerException() + .isThrownBy(() -> WebsocketServerTransport.create(null, 8000)) + .withMessage("bindAddress must not be null"); + } + + @DisplayName("create throws NullPointerException with null client") + @Test + void createNullHttpClient() { + assertThatNullPointerException() + .isThrownBy(() -> WebsocketServerTransport.create((HttpServer) null)) + .withMessage("server must not be null"); + } + + @DisplayName("create throws NullPointerException with null address") + @Test + void createNullInetSocketAddress() { + assertThatNullPointerException() + .isThrownBy(() -> WebsocketServerTransport.create((InetSocketAddress) null)) + .withMessage("address must not be null"); + } + + @DisplayName("creates server with port") + @Test + void createPort() { + assertThat(WebsocketServerTransport.create(8000)).isNotNull(); + } + + @DisplayName("starts server") + @Test + void start() { + InetSocketAddress address = InetSocketAddress.createUnresolved("localhost", 0); + + WebsocketServerTransport serverTransport = WebsocketServerTransport.create(address); + + serverTransport + .start(duplexConnection -> Mono.empty()) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + } + + @DisplayName("start throws NullPointerException with null acceptor") + @Test + void startNullAcceptor() { + assertThatNullPointerException() + .isThrownBy(() -> WebsocketServerTransport.create(8000).start(null)) + .withMessage("acceptor must not be null"); + } +} diff --git a/rsocket-transport-netty/src/test/resources/log4j.properties b/rsocket-transport-netty/src/test/resources/log4j.properties deleted file mode 100644 index 04bbed8ae..000000000 --- a/rsocket-transport-netty/src/test/resources/log4j.properties +++ /dev/null @@ -1,18 +0,0 @@ -# -# Copyright 2016 Netflix, Inc. -#

-# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -#

-# http://www.apache.org/licenses/LICENSE-2.0 -#

-# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on -# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the -# specific language governing permissions and limitations under the License. -# -log4j.rootLogger=INFO, stdout - -log4j.appender.stdout=org.apache.log4j.ConsoleAppender -log4j.appender.stdout.layout=org.apache.log4j.PatternLayout -log4j.appender.stdout.layout.ConversionPattern=%d{HH:mm:ss,SSS} %5p [%t] (%F) - %m%n -log4j.logger.io.rsocketogger=Debug \ No newline at end of file diff --git a/rsocket-transport-netty/src/test/resources/logback-test.xml b/rsocket-transport-netty/src/test/resources/logback-test.xml new file mode 100644 index 000000000..981d6d0b6 --- /dev/null +++ b/rsocket-transport-netty/src/test/resources/logback-test.xml @@ -0,0 +1,42 @@ + + + + + + + + %date{HH:mm:ss.SSS} %-10thread %-42logger %msg%n + + + + + + + + + + + + + + + + + + + + diff --git a/rsocket-transport-netty/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker b/rsocket-transport-netty/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker new file mode 100644 index 000000000..ca6ee9cea --- /dev/null +++ b/rsocket-transport-netty/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker @@ -0,0 +1 @@ +mock-maker-inline \ No newline at end of file diff --git a/settings.gradle b/settings.gradle index 2509c250c..25c3feee5 100644 --- a/settings.gradle +++ b/settings.gradle @@ -1,11 +1,11 @@ /* - * Copyright 2016 Netflix, Inc. + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,13 +13,29 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +plugins { + id 'com.gradle.enterprise' version '3.1' +} + +rootProject.name = 'rsocket-java' -rootProject.name='rsocket' -include 'rsocket-load-balancer' include 'rsocket-core' -include 'rsocket-examples' -include 'rsocket-spectator' +include 'rsocket-load-balancer' +include 'rsocket-micrometer' include 'rsocket-test' -include 'rsocket-transport-aeron' include 'rsocket-transport-local' include 'rsocket-transport-netty' +include 'rsocket-bom' + +include 'rsocket-examples' +include 'benchmarks' + + + +gradleEnterprise { + buildScan { + termsOfServiceUrl = 'https://gradle.com/terms-of-service' + termsOfServiceAgree = 'yes' + } +} +