diff --git a/nemo_text_processing/text_normalization/ko/taggers/date.py b/nemo_text_processing/text_normalization/ko/taggers/date.py index 4f2da5702..45943e1a3 100644 --- a/nemo_text_processing/text_normalization/ko/taggers/date.py +++ b/nemo_text_processing/text_normalization/ko/taggers/date.py @@ -249,6 +249,17 @@ def __init__(self, cardinal: GraphFst, deterministic: bool = True): + pynutil.insert("\"") ) + month_josa = pynini.union("에", "은", "는", "에는").optimize() + + individual_month_component_with_josa = ( + pynutil.insert('month: "') + + month_cardinal + + pynutil.delete("월") + + pynutil.insert("월") + + pynini.closure(month_josa, 0, 1) + + pynutil.insert('"') + ).optimize() + individual_day_component = ( pynutil.insert("day: \"") + cardinal_lz @@ -272,6 +283,7 @@ def __init__(self, cardinal: GraphFst, deterministic: bool = True): day_and_weekday_component | month_and_weekday_component | individual_year_component + | individual_month_component_with_josa | individual_month_component | individual_day_component | week_component diff --git a/nemo_text_processing/text_normalization/ko/verbalizers/post_processing.py b/nemo_text_processing/text_normalization/ko/verbalizers/post_processing.py new file mode 100644 index 000000000..45fcb259f --- /dev/null +++ b/nemo_text_processing/text_normalization/ko/verbalizers/post_processing.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pynini +from pynini.lib import pynutil + +from nemo_text_processing.text_normalization.ko.graph_utils import NEMO_SIGMA, NEMO_SPACE, generator_main +from nemo_text_processing.utils.logging import logger + + +class PostProcessingFst: + def __init__(self, cache_dir: str = None, overwrite_cache: bool = False): + far_file = None + if cache_dir is not None and cache_dir != "None": + os.makedirs(cache_dir, exist_ok=True) + far_file = os.path.join(cache_dir, "ko_tn_post_processing.far") + + if not overwrite_cache and far_file and os.path.exists(far_file): + self.fst = pynini.Far(far_file, mode="r")["post_process_graph"] + logger.info(f"Post processing graph was restored from {far_file}.") + else: + self.fst = self.get_postprocess_graph() + if far_file: + generator_main(far_file, {"post_process_graph": self.fst}) + + def get_postprocess_graph(self): + delete_space = pynutil.delete(NEMO_SPACE) + + vowel_final = pynini.union("아", "야", "어", "여", "오", "요", "우", "유", "이", "애", "에", "사", "오", "구") + + rule_i_to_ga = pynini.cdrewrite( + delete_space + pynini.cross("이 ", "가 "), + vowel_final, + "", + NEMO_SIGMA, + ) + + rule_eun_to_neun = pynini.cdrewrite( + delete_space + pynini.cross("은 ", "는 "), + vowel_final, + "", + NEMO_SIGMA, + ) + + rule_eul_to_reul = pynini.cdrewrite( + delete_space + pynini.cross("을 ", "를 "), + vowel_final, + "", + NEMO_SIGMA, + ) + + graph = rule_i_to_ga @ rule_eun_to_neun @ rule_eul_to_reul + return graph.optimize() diff --git a/nemo_text_processing/text_normalization/normalize.py b/nemo_text_processing/text_normalization/normalize.py index 5e2f9ebb5..12661bd3d 100644 --- a/nemo_text_processing/text_normalization/normalize.py +++ b/nemo_text_processing/text_normalization/normalize.py @@ -187,7 +187,11 @@ def __init__( self.post_processor = PostProcessingFst(cache_dir=cache_dir, overwrite_cache=overwrite_cache) elif lang == 'ko': from nemo_text_processing.text_normalization.ko.taggers.tokenize_and_classify import ClassifyFst + from nemo_text_processing.text_normalization.ko.verbalizers.post_processing import PostProcessingFst from nemo_text_processing.text_normalization.ko.verbalizers.verbalize_final import VerbalizeFinalFst + + if post_process: + self.post_processor = PostProcessingFst(cache_dir=cache_dir, overwrite_cache=overwrite_cache) else: raise NotImplementedError(f"Language {lang} has not been supported yet.") @@ -388,7 +392,11 @@ def normalize( return text output = SPACE_DUP.sub(' ', output[1:]) - if self.lang in ["en", "hi", "vi"] and hasattr(self, 'post_processor') and self.post_processor is not None: + if ( + self.lang in ["en", "hi", "vi", "ko"] + and hasattr(self, 'post_processor') + and self.post_processor is not None + ): output = self.post_process(output) if punct_post_process: diff --git a/tests/nemo_text_processing/ko/data_text_normalization/test_cases_fraction.txt b/tests/nemo_text_processing/ko/data_text_normalization/test_cases_fraction.txt index a183be59b..fc39fd495 100644 --- a/tests/nemo_text_processing/ko/data_text_normalization/test_cases_fraction.txt +++ b/tests/nemo_text_processing/ko/data_text_normalization/test_cases_fraction.txt @@ -11,4 +11,18 @@ 1과1/3~일과 삼분의 일 1과√1/4~일과 사분의 루트 일 3분의1~삼분의 일 -121분의3221~백이십일분의 삼천이백이십일 \ No newline at end of file +121분의3221~백이십일분의 삼천이백이십일 +이번 경기의 3/5이 중요하다~이번 경기의 오분의 삼 이 중요하다 +전체 구역의 4/7이 통제되었다~전체 구역의 칠분의 사가 통제되었다 +설문 응답자의 9/10 이 찬성했다~설문 응답자의 십분의 구가 찬성했다 +그 중 2/3은 성공했다~그 중 삼분의 이는 성공했다 +참가자의 5/8이 탈락했다~참가자의 팔분의 오가 탈락했다 +참가자의 6/7 이 통과했다~참가자의 칠분의 육 이 통과했다 +전체의 3/4 이 감소했다~전체의 사분의 삼 이 감소했다 +응답자의 2/5이 반대했다~응답자의 오분의 이가 반대했다 +학생의 7/9 이 합격했다~학생의 구분의 칠 이 합격했다 +전체의 1/2 이 남았다~전체의 이분의 일 이 남았다 +그 중 4/5이 성공했다~그 중 오분의 사가 성공했다 +전체의 5/6이 완료되었다~전체의 육분의 오가 완료되었다 +참가자의 3/8이 탈락했다~참가자의 팔분의 삼 이 탈락했다 +응답자의 6/10 이 동의했다~응답자의 십분의 육 이 동의했다 \ No newline at end of file