From 81765d02c46b9b970208991ca33808b8803f94cc Mon Sep 17 00:00:00 2001 From: NoahLan <6995syu@163.com> Date: Sat, 17 Jun 2023 13:39:58 +0800 Subject: [PATCH] first commit --- .gitignore | 21 + go.mod | 24 + go.sum | 40 ++ internal/common/cli.go | 70 ++ internal/common/cli_nonwin.go | 28 + internal/common/cli_windows.go | 28 + internal/common/env.go | 78 +++ internal/common/sys.go | 30 + internal/convert/bool.go | 38 + narr/check.go | 105 +++ narr/check_test.go | 108 +++ narr/collection.go | 427 ++++++++++++ narr/collection_test.go | 317 +++++++++ narr/convert.go | 266 +++++++ narr/convert_test.go | 152 ++++ narr/format.go | 124 ++++ narr/format_test.go | 23 + narr/types.go | 64 ++ narr/types_test.go | 54 ++ narr/util.go | 130 ++++ narr/util_test.go | 128 ++++ nbyte/buffer.go | 65 ++ nbyte/buffer_test.go | 25 + nbyte/bytex.go | 18 + nbyte/check.go | 62 ++ nbyte/encoder.go | 63 ++ nbyte/encoder_test.go | 27 + nbyte/util.go | 104 +++ nbyte/util_test.go | 55 ++ ncli/cmdline/builder.go | 84 +++ ncli/cmdline/builder_test.go | 52 ++ ncli/cmdline/cmdline.go | 41 ++ ncli/cmdline/parser.go | 173 +++++ ncli/cmdline/parser_test.go | 145 ++++ ncli/color_print.go | 110 +++ ncli/info.go | 54 ++ ncli/info_nonwin.go | 11 + ncli/info_windows.go | 10 + ncli/read.go | 142 ++++ ncli/read_nonwin.go | 3 + ncli/read_test.go | 54 ++ ncli/read_windows.go | 1 + ncli/util.go | 147 ++++ ncli/util_test.go | 144 ++++ ncrypt/aes_des.go | 411 +++++++++++ ncrypt/base64.go | 14 + ncrypt/bcrypt.go | 18 + ncrypt/hmac.go | 38 + ncrypt/md5.go | 52 ++ ncrypt/rsa.go | 128 ++++ ncrypt/sha.go | 29 + ndef/consts.go | 27 + ndef/errors.go | 6 + ndef/formatter.go | 56 ++ ndef/serializer.go | 37 + ndef/symbols.go | 41 ++ ndef/types.go | 46 ++ nenv/info.go | 169 +++++ nenv/parse.go | 25 + nfs/check.go | 136 ++++ nfs/check_test.go | 87 +++ nfs/find.go | 155 +++++ nfs/find_test.go | 95 +++ nfs/finder/README.md | 23 + nfs/finder/config.go | 492 +++++++++++++ nfs/finder/elem.go | 48 ++ nfs/finder/finder.go | 353 ++++++++++ nfs/finder/finder_test.go | 160 +++++ nfs/finder/matcher.go | 138 ++++ nfs/finder/matchers.go | 289 ++++++++ nfs/finder/matchers_test.go | 84 +++ nfs/finder/testdata/.dotdir/some.txt | 0 nfs/finder/testdata/.env | 0 nfs/finder/testdata/test.txt | 1 + nfs/fn.go | 7 + nfs/info.go | 72 ++ nfs/info_nonwin.go | 18 + nfs/info_test.go | 18 + nfs/info_windows.go | 18 + nfs/oper.go | 255 +++++++ nfs/oper_read.go | 137 ++++ nfs/oper_read_test.go | 43 ++ nfs/oper_test.go | 111 +++ nfs/oper_write.go | 100 +++ nfs/oper_write_test.go | 26 + nfs/testdata/.dotdir/some.txt | 0 nfs/testdata/.env | 0 nfs/testdata/cp-file-dst.txt | 1 + nfs/testdata/cp-file-src.txt | 1 + nfs/testdata/get-contents.txt | 1 + nfs/testdata/mimetext.txt | 1 + nfs/util.go | 155 +++++ nfs/util_nonwin.go | 16 + nfs/util_nonwin_test.go | 19 + nfs/util_test.go | 59 ++ nfs/util_windows_test.go | 19 + ngo/base_fn.go | 72 ++ ngo/codec/serializer_json.go | 25 + ngo/ext_fn.go | 23 + nlog/color.go | 22 + nlog/color_test.go | 32 + nlog/config.go | 45 ++ nlog/fields.go | 48 ++ nlog/fields_test.go | 120 ++++ nlog/lesslogger.go | 27 + nlog/lesslogger_test.go | 34 + nlog/lesswriter.go | 22 + nlog/lesswriter_test.go | 19 + nlog/limitedexecutor.go | 40 ++ nlog/limitedexecutor_test.go | 61 ++ nlog/logger.go | 50 ++ nlog/logs.go | 464 +++++++++++++ nlog/logs_test.go | 838 +++++++++++++++++++++++ nlog/logwriter.go | 22 + nlog/richlogger.go | 179 +++++ nlog/richlogger_test.go | 318 +++++++++ nlog/rotatelogger.go | 443 ++++++++++++ nlog/rotatelogger_test.go | 348 ++++++++++ nlog/syslog.go | 15 + nlog/syslog_test.go | 59 ++ nlog/util.go | 55 ++ nlog/util_test.go | 72 ++ nlog/vars.go | 66 ++ nlog/writer.go | 403 +++++++++++ nlog/writer_test.go | 221 ++++++ nmap/check.go | 63 ++ nmap/check_test.go | 33 + nmap/convert.go | 142 ++++ nmap/convert_test.go | 100 +++ nmap/data.go | 259 +++++++ nmap/data_test.go | 181 +++++ nmap/errors.go | 39 ++ nmap/format.go | 124 ++++ nmap/format_test.go | 48 ++ nmap/get.go | 169 +++++ nmap/get_test.go | 211 ++++++ nmap/setval.go | 339 +++++++++ nmap/setval_test.go | 188 +++++ nmap/smap.go | 126 ++++ nmap/smap_test.go | 61 ++ nmap/util.go | 128 ++++ nmap/util_test.go | 74 ++ nmath/check.go | 97 +++ nmath/check_test.go | 67 ++ nmath/convert.go | 404 +++++++++++ nmath/convert_test.go | 164 +++++ nmath/number.go | 14 + nmath/number_test.go | 36 + nmath/util.go | 83 +++ nmath/util_test.go | 61 ++ nnet/util.go | 60 ++ nrandom/bytes.go | 20 + nrandom/const.go | 14 + nrandom/id.go | 66 ++ nrandom/int.go | 34 + nrandom/int_test.go | 24 + nrandom/snowflake/generator.go | 51 ++ nrandom/snowflake/options.go | 45 ++ nrandom/snowflake/snowflake.go | 37 + nrandom/snowflake/snowflake_offset.go | 182 +++++ nrandom/string.go | 50 ++ nrandom/uuid.go | 73 ++ nrandom/uuid_test.go | 62 ++ nreflect/check.go | 158 +++++ nreflect/convert.go | 195 ++++++ nreflect/util.go | 207 ++++++ nstd/chan.go | 63 ++ nstd/check.go | 120 ++++ nstd/gofunc.go | 90 +++ nstd/io/writer.go | 20 + nstd/io/writer_wrapper.go | 48 ++ nstd/tea/tea.go | 491 +++++++++++++ nstr/ac/README.md | 85 +++ nstr/ac/ahocorasick.go | 386 +++++++++++ nstr/ac/automaton.go | 222 ++++++ nstr/ac/byte_frequencies.go | 260 +++++++ nstr/ac/classes.go | 77 +++ nstr/ac/dfa.go | 729 ++++++++++++++++++++ nstr/ac/nfa.go | 822 ++++++++++++++++++++++ nstr/ac/prefilter.go | 601 ++++++++++++++++ nstr/ac/util.go | 9 + nstr/check.go | 406 +++++++++++ nstr/codec.go | 71 ++ nstr/convert.go | 311 +++++++++ nstr/filter.go | 57 ++ nstr/match.go | 131 ++++ nstr/padding.go | 130 ++++ nstr/parser.go | 140 ++++ nstr/repeat.go | 47 ++ nstr/runes.go | 78 +++ nstr/split.go | 159 +++++ nstr/textutil/textutil.go | 63 ++ nstr/textutil/textutil_test.go | 147 ++++ nstr/textutil/var_replacer.go | 212 ++++++ nstr/util.go | 333 +++++++++ nstruct/check.go | 14 + nstruct/init.go | 178 +++++ nstruct/mapconv.go | 173 +++++ nstruct/tags.go | 273 ++++++++ nsys/atomic/atomic_duration.go | 36 + nsys/atomic/atomic_int64.go | 60 ++ nsys/clipboard/clipboard.go | 210 ++++++ nsys/clipboard/clipboard_test.go | 40 ++ nsys/clipboard/testdata/read-from-cb.txt | 1 + nsys/clipboard/testdata/testcb.txt | 1 + nsys/clipboard/util.go | 87 +++ nsys/clipboard/util_darwin.go | 15 + nsys/clipboard/util_unix.go | 20 + nsys/clipboard/util_windows.go | 15 + nsys/cmdn/cmd_darwin.go | 21 + nsys/cmdn/cmd_freebsd.go | 21 + nsys/cmdn/cmd_linux.go | 21 + nsys/cmdn/cmd_windows.go | 23 + nsys/cmdn/command.go | 18 + nsys/cmdn/options.go | 49 ++ nsys/cmdn/proc.go | 364 ++++++++++ nsys/cmdn/response.go | 40 ++ nsys/cmdn/serializer_plain.go | 48 ++ nsys/cmdr/cmd.go | 449 ++++++++++++ nsys/cmdr/cmd_test.go | 35 + nsys/cmdr/cmdr.go | 32 + nsys/cmdr/runner.go | 316 +++++++++ nsys/cmdr/runner_test.go | 41 ++ nsys/exec.go | 71 ++ nsys/retry/retry.go | 87 +++ nsys/retry/retry_test.go | 73 ++ nsys/sysenv.go | 184 +++++ nsys/sysutil.go | 47 ++ nsys/sysutil_darwin.go | 36 + nsys/sysutil_nonwin.go | 17 + nsys/sysutil_test.go | 23 + nsys/sysutil_unix.go | 54 ++ nsys/sysutil_windows.go | 56 ++ ntest/assert/assert.go | 8 + ntest/assert/assertions.go | 30 + ntest/assert/assertions_methods.go | 243 +++++++ ntest/assert/assertions_test.go | 43 ++ ntest/assert/asserts.go | 730 ++++++++++++++++++++ ntest/assert/asserts_test.go | 34 + ntest/assert/util.go | 190 +++++ ntest/buffer.go | 43 ++ ntest/buffer_test.go | 20 + ntest/mock/env.go | 118 ++++ ntest/mock/env_test.go | 55 ++ ntest/mock/fs.go | 42 ++ ntest/writer.go | 63 ++ ntest/writer_test.go | 28 + ntime/check.go | 1 + ntime/config.go | 11 + ntime/format.go | 240 +++++++ ntime/gotime.go | 115 ++++ ntime/ntime.go | 344 ++++++++++ ntime/template.go | 43 ++ ntime/ticker.go | 78 +++ ntime/util.go | 235 +++++++ 255 files changed, 28983 insertions(+) create mode 100644 .gitignore create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/common/cli.go create mode 100644 internal/common/cli_nonwin.go create mode 100644 internal/common/cli_windows.go create mode 100644 internal/common/env.go create mode 100644 internal/common/sys.go create mode 100644 internal/convert/bool.go create mode 100644 narr/check.go create mode 100644 narr/check_test.go create mode 100644 narr/collection.go create mode 100644 narr/collection_test.go create mode 100644 narr/convert.go create mode 100644 narr/convert_test.go create mode 100644 narr/format.go create mode 100644 narr/format_test.go create mode 100644 narr/types.go create mode 100644 narr/types_test.go create mode 100644 narr/util.go create mode 100644 narr/util_test.go create mode 100644 nbyte/buffer.go create mode 100644 nbyte/buffer_test.go create mode 100644 nbyte/bytex.go create mode 100644 nbyte/check.go create mode 100644 nbyte/encoder.go create mode 100644 nbyte/encoder_test.go create mode 100644 nbyte/util.go create mode 100644 nbyte/util_test.go create mode 100644 ncli/cmdline/builder.go create mode 100644 ncli/cmdline/builder_test.go create mode 100644 ncli/cmdline/cmdline.go create mode 100644 ncli/cmdline/parser.go create mode 100644 ncli/cmdline/parser_test.go create mode 100644 ncli/color_print.go create mode 100644 ncli/info.go create mode 100644 ncli/info_nonwin.go create mode 100644 ncli/info_windows.go create mode 100644 ncli/read.go create mode 100644 ncli/read_nonwin.go create mode 100644 ncli/read_test.go create mode 100644 ncli/read_windows.go create mode 100644 ncli/util.go create mode 100644 ncli/util_test.go create mode 100644 ncrypt/aes_des.go create mode 100644 ncrypt/base64.go create mode 100644 ncrypt/bcrypt.go create mode 100644 ncrypt/hmac.go create mode 100644 ncrypt/md5.go create mode 100644 ncrypt/rsa.go create mode 100644 ncrypt/sha.go create mode 100644 ndef/consts.go create mode 100644 ndef/errors.go create mode 100644 ndef/formatter.go create mode 100644 ndef/serializer.go create mode 100644 ndef/symbols.go create mode 100644 ndef/types.go create mode 100644 nenv/info.go create mode 100644 nenv/parse.go create mode 100644 nfs/check.go create mode 100644 nfs/check_test.go create mode 100644 nfs/find.go create mode 100644 nfs/find_test.go create mode 100644 nfs/finder/README.md create mode 100644 nfs/finder/config.go create mode 100644 nfs/finder/elem.go create mode 100644 nfs/finder/finder.go create mode 100644 nfs/finder/finder_test.go create mode 100644 nfs/finder/matcher.go create mode 100644 nfs/finder/matchers.go create mode 100644 nfs/finder/matchers_test.go create mode 100644 nfs/finder/testdata/.dotdir/some.txt create mode 100644 nfs/finder/testdata/.env create mode 100644 nfs/finder/testdata/test.txt create mode 100644 nfs/fn.go create mode 100644 nfs/info.go create mode 100644 nfs/info_nonwin.go create mode 100644 nfs/info_test.go create mode 100644 nfs/info_windows.go create mode 100644 nfs/oper.go create mode 100644 nfs/oper_read.go create mode 100644 nfs/oper_read_test.go create mode 100644 nfs/oper_test.go create mode 100644 nfs/oper_write.go create mode 100644 nfs/oper_write_test.go create mode 100644 nfs/testdata/.dotdir/some.txt create mode 100644 nfs/testdata/.env create mode 100644 nfs/testdata/cp-file-dst.txt create mode 100644 nfs/testdata/cp-file-src.txt create mode 100644 nfs/testdata/get-contents.txt create mode 100644 nfs/testdata/mimetext.txt create mode 100644 nfs/util.go create mode 100644 nfs/util_nonwin.go create mode 100644 nfs/util_nonwin_test.go create mode 100644 nfs/util_test.go create mode 100644 nfs/util_windows_test.go create mode 100644 ngo/base_fn.go create mode 100644 ngo/codec/serializer_json.go create mode 100644 ngo/ext_fn.go create mode 100644 nlog/color.go create mode 100644 nlog/color_test.go create mode 100644 nlog/config.go create mode 100644 nlog/fields.go create mode 100644 nlog/fields_test.go create mode 100644 nlog/lesslogger.go create mode 100644 nlog/lesslogger_test.go create mode 100644 nlog/lesswriter.go create mode 100644 nlog/lesswriter_test.go create mode 100644 nlog/limitedexecutor.go create mode 100644 nlog/limitedexecutor_test.go create mode 100644 nlog/logger.go create mode 100644 nlog/logs.go create mode 100644 nlog/logs_test.go create mode 100644 nlog/logwriter.go create mode 100644 nlog/richlogger.go create mode 100644 nlog/richlogger_test.go create mode 100644 nlog/rotatelogger.go create mode 100644 nlog/rotatelogger_test.go create mode 100644 nlog/syslog.go create mode 100644 nlog/syslog_test.go create mode 100644 nlog/util.go create mode 100644 nlog/util_test.go create mode 100644 nlog/vars.go create mode 100644 nlog/writer.go create mode 100644 nlog/writer_test.go create mode 100644 nmap/check.go create mode 100644 nmap/check_test.go create mode 100644 nmap/convert.go create mode 100644 nmap/convert_test.go create mode 100644 nmap/data.go create mode 100644 nmap/data_test.go create mode 100644 nmap/errors.go create mode 100644 nmap/format.go create mode 100644 nmap/format_test.go create mode 100644 nmap/get.go create mode 100644 nmap/get_test.go create mode 100644 nmap/setval.go create mode 100644 nmap/setval_test.go create mode 100644 nmap/smap.go create mode 100644 nmap/smap_test.go create mode 100644 nmap/util.go create mode 100644 nmap/util_test.go create mode 100644 nmath/check.go create mode 100644 nmath/check_test.go create mode 100644 nmath/convert.go create mode 100644 nmath/convert_test.go create mode 100644 nmath/number.go create mode 100644 nmath/number_test.go create mode 100644 nmath/util.go create mode 100644 nmath/util_test.go create mode 100644 nnet/util.go create mode 100644 nrandom/bytes.go create mode 100644 nrandom/const.go create mode 100644 nrandom/id.go create mode 100644 nrandom/int.go create mode 100644 nrandom/int_test.go create mode 100644 nrandom/snowflake/generator.go create mode 100644 nrandom/snowflake/options.go create mode 100644 nrandom/snowflake/snowflake.go create mode 100644 nrandom/snowflake/snowflake_offset.go create mode 100644 nrandom/string.go create mode 100644 nrandom/uuid.go create mode 100644 nrandom/uuid_test.go create mode 100644 nreflect/check.go create mode 100644 nreflect/convert.go create mode 100644 nreflect/util.go create mode 100644 nstd/chan.go create mode 100644 nstd/check.go create mode 100644 nstd/gofunc.go create mode 100644 nstd/io/writer.go create mode 100644 nstd/io/writer_wrapper.go create mode 100644 nstd/tea/tea.go create mode 100644 nstr/ac/README.md create mode 100644 nstr/ac/ahocorasick.go create mode 100644 nstr/ac/automaton.go create mode 100644 nstr/ac/byte_frequencies.go create mode 100644 nstr/ac/classes.go create mode 100644 nstr/ac/dfa.go create mode 100644 nstr/ac/nfa.go create mode 100644 nstr/ac/prefilter.go create mode 100644 nstr/ac/util.go create mode 100644 nstr/check.go create mode 100644 nstr/codec.go create mode 100644 nstr/convert.go create mode 100644 nstr/filter.go create mode 100644 nstr/match.go create mode 100644 nstr/padding.go create mode 100644 nstr/parser.go create mode 100644 nstr/repeat.go create mode 100644 nstr/runes.go create mode 100644 nstr/split.go create mode 100644 nstr/textutil/textutil.go create mode 100644 nstr/textutil/textutil_test.go create mode 100644 nstr/textutil/var_replacer.go create mode 100644 nstr/util.go create mode 100644 nstruct/check.go create mode 100644 nstruct/init.go create mode 100644 nstruct/mapconv.go create mode 100644 nstruct/tags.go create mode 100644 nsys/atomic/atomic_duration.go create mode 100644 nsys/atomic/atomic_int64.go create mode 100644 nsys/clipboard/clipboard.go create mode 100644 nsys/clipboard/clipboard_test.go create mode 100644 nsys/clipboard/testdata/read-from-cb.txt create mode 100644 nsys/clipboard/testdata/testcb.txt create mode 100644 nsys/clipboard/util.go create mode 100644 nsys/clipboard/util_darwin.go create mode 100644 nsys/clipboard/util_unix.go create mode 100644 nsys/clipboard/util_windows.go create mode 100644 nsys/cmdn/cmd_darwin.go create mode 100644 nsys/cmdn/cmd_freebsd.go create mode 100644 nsys/cmdn/cmd_linux.go create mode 100644 nsys/cmdn/cmd_windows.go create mode 100644 nsys/cmdn/command.go create mode 100644 nsys/cmdn/options.go create mode 100644 nsys/cmdn/proc.go create mode 100644 nsys/cmdn/response.go create mode 100644 nsys/cmdn/serializer_plain.go create mode 100644 nsys/cmdr/cmd.go create mode 100644 nsys/cmdr/cmd_test.go create mode 100644 nsys/cmdr/cmdr.go create mode 100644 nsys/cmdr/runner.go create mode 100644 nsys/cmdr/runner_test.go create mode 100644 nsys/exec.go create mode 100644 nsys/retry/retry.go create mode 100644 nsys/retry/retry_test.go create mode 100644 nsys/sysenv.go create mode 100644 nsys/sysutil.go create mode 100644 nsys/sysutil_darwin.go create mode 100644 nsys/sysutil_nonwin.go create mode 100644 nsys/sysutil_test.go create mode 100644 nsys/sysutil_unix.go create mode 100644 nsys/sysutil_windows.go create mode 100644 ntest/assert/assert.go create mode 100644 ntest/assert/assertions.go create mode 100644 ntest/assert/assertions_methods.go create mode 100644 ntest/assert/assertions_test.go create mode 100644 ntest/assert/asserts.go create mode 100644 ntest/assert/asserts_test.go create mode 100644 ntest/assert/util.go create mode 100644 ntest/buffer.go create mode 100644 ntest/buffer_test.go create mode 100644 ntest/mock/env.go create mode 100644 ntest/mock/env_test.go create mode 100644 ntest/mock/fs.go create mode 100644 ntest/writer.go create mode 100644 ntest/writer_test.go create mode 100644 ntime/check.go create mode 100644 ntime/config.go create mode 100644 ntime/format.go create mode 100644 ntime/gotime.go create mode 100644 ntime/ntime.go create mode 100644 ntime/template.go create mode 100644 ntime/ticker.go create mode 100644 ntime/util.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f9e2e1a --- /dev/null +++ b/.gitignore @@ -0,0 +1,21 @@ +# Binaries for programs and plugins +*.exe +*.dll +*.so +*.dylib + +# Test binary, build with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Project-local glide cache, RE: https://github.com/Masterminds/glide/issues/736 +.glide/ + +.idea/ +.vscode/ + +*/logs/ +logs/ +*.log \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..15176dd --- /dev/null +++ b/go.mod @@ -0,0 +1,24 @@ +module git.noahlan.cn/noahlan/ntool + +go 1.20 + +require ( + github.com/gofrs/uuid/v5 v5.0.0 + github.com/gookit/color v1.5.3 + github.com/mattn/go-colorable v0.1.13 + go.opentelemetry.io/otel v1.16.0 + go.opentelemetry.io/otel/sdk v1.16.0 + go.opentelemetry.io/otel/trace v1.16.0 + golang.org/x/crypto v0.10.0 + golang.org/x/sys v0.9.0 + golang.org/x/term v0.9.0 + golang.org/x/text v0.10.0 +) + +require ( + github.com/go-logr/logr v1.2.4 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect + go.opentelemetry.io/otel/metric v1.16.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..b46366e --- /dev/null +++ b/go.sum @@ -0,0 +1,40 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= +github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/gofrs/uuid/v5 v5.0.0 h1:p544++a97kEL+svbcFbCQVM9KFu0Yo25UoISXGNNH9M= +github.com/gofrs/uuid/v5 v5.0.0/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/gookit/color v1.5.3 h1:twfIhZs4QLCtimkP7MOxlF3A0U/5cDPseRT9M/+2SCE= +github.com/gookit/color v1.5.3/go.mod h1:NUzwzeehUfl7GIb36pqId+UGmRfQcU/WiiyTTeNjHtE= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= +go.opentelemetry.io/otel v1.16.0 h1:Z7GVAX/UkAXPKsy94IU+i6thsQS4nb7LviLpnaNeW8s= +go.opentelemetry.io/otel v1.16.0/go.mod h1:vl0h9NUa1D5s1nv3A5vZOYWn8av4K8Ml6JDeHrT/bx4= +go.opentelemetry.io/otel/metric v1.16.0 h1:RbrpwVG1Hfv85LgnZ7+txXioPDoh6EdbZHo26Q3hqOo= +go.opentelemetry.io/otel/metric v1.16.0/go.mod h1:QE47cpOmkwipPiefDwo2wDzwJrlfxxNYodqc4xnGCo4= +go.opentelemetry.io/otel/sdk v1.16.0 h1:Z1Ok1YsijYL0CSJpHt4cS3wDDh7p572grzNrBMiMWgE= +go.opentelemetry.io/otel/sdk v1.16.0/go.mod h1:tMsIuKXuuIWPBAOrH+eHtvhTL+SntFtXF9QD68aP6p4= +go.opentelemetry.io/otel/trace v1.16.0 h1:8JRpaObFoW0pxuVPapkgH8UhHQj+bJW8jJsCZEu5MQs= +go.opentelemetry.io/otel/trace v1.16.0/go.mod h1:Yt9vYq1SdNz3xdjZZK7wcXv1qv2pwLkqr2QVwea0ef0= +golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= +golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= +golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= +golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.9.0 h1:GRRCnKYhdQrD8kfRAdQ6Zcw1P0OcELxGLKJvtjVMZ28= +golang.org/x/term v0.9.0/go.mod h1:M6DEAAIenWoTxdKrOltXcmDY3rSplQUkrvaDU5FcQyo= +golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58= +golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/common/cli.go b/internal/common/cli.go new file mode 100644 index 0000000..90958e2 --- /dev/null +++ b/internal/common/cli.go @@ -0,0 +1,70 @@ +package common + +import ( + "os" + "os/exec" + "path/filepath" + "strings" +) + +// ExecCmd a command and return output. +// +// Usage: +// +// ExecCmd("ls", []string{"-al"}) +func ExecCmd(binName string, args []string, workDir ...string) (string, error) { + // create a new Cmd instance + cmd := exec.Command(binName, args...) + if len(workDir) > 0 { + cmd.Dir = workDir[0] + } + + bs, err := cmd.Output() + return string(bs), err +} + +// curShell cache +var curShell string + +// CurrentShell get current used shell env file. +// +// eg "/bin/zsh" "/bin/bash". +// if onlyName=true, will return "zsh", "bash" +func CurrentShell(onlyName bool) (binPath string) { + var err error + if curShell == "" { + binPath = os.Getenv("SHELL") + if len(binPath) == 0 { + binPath, err = ShellExec("echo $SHELL") + if err != nil { + return "" + } + } + + binPath = strings.TrimSpace(binPath) + // cache result + curShell = binPath + } else { + binPath = curShell + } + + if onlyName && len(binPath) > 0 { + binPath = filepath.Base(binPath) + } + return +} + +// HasShellEnv has shell env check. +// +// Usage: +// +// HasShellEnv("sh") +// HasShellEnv("bash") +func HasShellEnv(shell string) bool { + // can also use: "echo $0" + out, err := ShellExec("echo OK", shell) + if err != nil { + return false + } + return strings.TrimSpace(out) == "OK" +} diff --git a/internal/common/cli_nonwin.go b/internal/common/cli_nonwin.go new file mode 100644 index 0000000..f2e8255 --- /dev/null +++ b/internal/common/cli_nonwin.go @@ -0,0 +1,28 @@ +//go:build !windows + +package common + +import ( + "bytes" + "os/exec" +) + +// ShellExec exec command by shell +// cmdLine e.g. "ls -al" +func ShellExec(cmdLine string, shells ...string) (string, error) { + // shell := "/bin/sh" + shell := "sh" + if len(shells) > 0 { + shell = shells[0] + } + + var out bytes.Buffer + + cmd := exec.Command(shell, "-c", cmdLine) + cmd.Stdout = &out + + if err := cmd.Run(); err != nil { + return "", err + } + return out.String(), nil +} diff --git a/internal/common/cli_windows.go b/internal/common/cli_windows.go new file mode 100644 index 0000000..9495c0b --- /dev/null +++ b/internal/common/cli_windows.go @@ -0,0 +1,28 @@ +//go:build windows + +package common + +import ( + "bytes" + "os/exec" +) + +// ShellExec exec command by shell +// cmdLine e.g. "ls -al" +func ShellExec(cmdLine string, shells ...string) (string, error) { + // shell := "/bin/sh" + shell := "cmd" + if len(shells) > 0 { + shell = shells[0] + } + + var out bytes.Buffer + + cmd := exec.Command(shell, "/c", cmdLine) + cmd.Stdout = &out + + if err := cmd.Run(); err != nil { + return "", err + } + return out.String(), nil +} diff --git a/internal/common/env.go b/internal/common/env.go new file mode 100644 index 0000000..d26ac01 --- /dev/null +++ b/internal/common/env.go @@ -0,0 +1,78 @@ +package common + +import ( + "os" + "regexp" + "strings" +) + +// Environ like os.Environ, but will returns key-value map[string]string data. +func Environ() map[string]string { + envList := os.Environ() + envMap := make(map[string]string, len(envList)) + + for _, str := range envList { + nodes := strings.SplitN(str, "=", 2) + + if len(nodes) < 2 { + envMap[nodes[0]] = "" + } else { + envMap[nodes[0]] = nodes[1] + } + } + return envMap +} + +// parse env value, allow: +// +// only key - "${SHELL}" +// with default - "${NotExist | defValue}" +// multi key - "${GOPATH}/${APP_ENV | prod}/dir" +// +// Notice: +// +// must add "?" - To ensure that there is no greedy match +// var envRegex = regexp.MustCompile(`\${[\w-| ]+}`) +var envRegex = regexp.MustCompile(`\${.+?}`) + +// ParseEnvVar parse ENV var value from input string, support default value. +// +// Format: +// +// ${var_name} Only var name +// ${var_name | default} With default value +// +// Usage: +// +// comfunc.ParseEnvVar("${ APP_NAME }") +// comfunc.ParseEnvVar("${ APP_ENV | dev }") +func ParseEnvVar(val string, getFn func(string) string) (newVal string) { + if !strings.Contains(val, "${") { + return val + } + + // default use os.Getenv + if getFn == nil { + getFn = os.Getenv + } + + var name, def string + return envRegex.ReplaceAllStringFunc(val, func(eVar string) string { + // eVar like "${NotExist|defValue}", first remove "${" and "}", then split it + ss := strings.SplitN(eVar[2:len(eVar)-1], "|", 2) + + // with default value. ${NotExist|defValue} + if len(ss) == 2 { + name, def = strings.TrimSpace(ss[0]), strings.TrimSpace(ss[1]) + } else { + name = strings.TrimSpace(ss[0]) + } + + // get ENV value by name + eVal := getFn(name) + if eVal == "" { + eVal = def + } + return eVal + }) +} diff --git a/internal/common/sys.go b/internal/common/sys.go new file mode 100644 index 0000000..2635280 --- /dev/null +++ b/internal/common/sys.go @@ -0,0 +1,30 @@ +package common + +import "os" + +// Workdir get +func Workdir() string { + dir, _ := os.Getwd() + return dir +} + +// ExpandHome will parse first `~` as user home dir path. +func ExpandHome(pathStr string) string { + if len(pathStr) == 0 { + return pathStr + } + + if pathStr[0] != '~' { + return pathStr + } + + if len(pathStr) > 1 && pathStr[1] != '/' && pathStr[1] != '\\' { + return pathStr + } + + homeDir, err := os.UserHomeDir() + if err != nil { + return pathStr + } + return homeDir + pathStr[1:] +} diff --git a/internal/convert/bool.go b/internal/convert/bool.go new file mode 100644 index 0000000..7220dd5 --- /dev/null +++ b/internal/convert/bool.go @@ -0,0 +1,38 @@ +package convert + +import ( + "fmt" + "git.noahlan.cn/noahlan/ntool/ndef" + "strings" +) + +// Bool try to convert type to bool +func Bool(v any) bool { + bl, _ := ToBool(v) + return bl +} + +// ToBool try to convert type to bool +func ToBool(v any) (bool, error) { + if bl, ok := v.(bool); ok { + return bl, nil + } + + if str, ok := v.(string); ok { + return StrToBool(str) + } + return false, ndef.ErrConvType +} + +// StrToBool parse string to bool. like strconv.ParseBool() +func StrToBool(s string) (bool, error) { + lower := strings.ToLower(s) + switch lower { + case "1", "on", "yes", "true": + return true, nil + case "0", "off", "no", "false": + return false, nil + } + + return false, fmt.Errorf("'%s' cannot convert to bool", s) +} diff --git a/narr/check.go b/narr/check.go new file mode 100644 index 0000000..19536c2 --- /dev/null +++ b/narr/check.go @@ -0,0 +1,105 @@ +package narr + +import ( + "git.noahlan.cn/noahlan/ntool/ndef" + "git.noahlan.cn/noahlan/ntool/nmath" + "reflect" + "strings" +) + +// NotIn check the given value whether not in the list +func NotIn[T ndef.ScalarType](list []T, value T) bool { + return !In(list, value) +} + +// In check the given value whether in the list +func In[T ndef.ScalarType](list []T, value T) bool { + for _, elem := range list { + if elem == value { + return true + } + } + return false +} + +// ContainsAll check given values is sub-list of sample list. +func ContainsAll[T ndef.ScalarType](list, values []T) bool { + return IsSubList(values, list) +} + +// IsSubList check given values is sub-list of sample list. +func IsSubList[T ndef.ScalarType](values, list []T) bool { + for _, value := range values { + if !In(list, value) { + return false + } + } + return true +} + +// IsParent check given values is parent-list of samples. +func IsParent[T ndef.ScalarType](values, list []T) bool { + return IsSubList(list, values) +} + +// StringsHas check the []string contains the given element +func StringsHas(ss []string, val string) bool { + return In(ss, val) +} + +// IntsHas check the []int contains the given value +func IntsHas(ints []int, val int) bool { + return In(ints, val) +} + +// Int64sHas check the []int64 contains the given value +func Int64sHas(ints []int64, val int64) bool { + return In(ints, val) +} + +// HasValue check array(strings, intXs, uintXs) should be contained the given value(int(X),string). +func HasValue(arr, val any) bool { return Contains(arr, val) } + +// Contains check slice/array(strings, intXs, uintXs) should be contained the given value(int(X),string). +// +// TIP: Difference the In(), Contains() will try to convert value type, +// and Contains() support array type. +func Contains(arr, val any) bool { + if val == nil || arr == nil { + return false + } + + // if is string value + if strVal, ok := val.(string); ok { + if ss, ok := arr.([]string); ok { + return StringsHas(ss, strVal) + } + + rv := reflect.ValueOf(arr) + if rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array { + for i := 0; i < rv.Len(); i++ { + if v, ok := rv.Index(i).Interface().(string); ok && strings.EqualFold(v, strVal) { + return true + } + } + } + + return false + } + + // as int value + intVal, err := nmath.Int64(val) + if err != nil { + return false + } + + if int64s, err := ToInt64s(arr); err == nil { + return Int64sHas(int64s, intVal) + } + return false +} + +// NotContains check array(strings, ints, uints) should be not contains the given value. +func NotContains(arr, val any) bool { + return !Contains(arr, val) +} diff --git a/narr/check_test.go b/narr/check_test.go new file mode 100644 index 0000000..4659652 --- /dev/null +++ b/narr/check_test.go @@ -0,0 +1,108 @@ +package narr_test + +import ( + "git.noahlan.cn/noahlan/ntool/narr" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestIntsHas(t *testing.T) { + ints := []int{2, 4, 5} + assert.True(t, narr.IntsHas(ints, 2)) + assert.True(t, narr.IntsHas(ints, 5)) + assert.False(t, narr.IntsHas(ints, 3)) +} + +func TestInt64sHas(t *testing.T) { + ints := []int64{2, 4, 5} + assert.True(t, narr.Int64sHas(ints, 2)) + assert.True(t, narr.Int64sHas(ints, 5)) + assert.False(t, narr.Int64sHas(ints, 3)) +} + +func TestStringsHas(t *testing.T) { + ss := []string{"a", "b"} + assert.True(t, narr.StringsHas(ss, "a")) + assert.True(t, narr.StringsHas(ss, "b")) + + assert.False(t, narr.StringsHas(ss, "c")) +} + +func TestInAndNotIn(t *testing.T) { + is := assert.New(t) + + arr := []int{1, 2, 3} + is.True(narr.In(arr, 2)) + is.False(narr.NotIn(arr, 2)) + + arr1 := []rune{'a', 'b'} + is.True(narr.In(arr1, 'b')) + is.False(narr.NotIn(arr1, 'b')) + + arr2 := []string{"a", "b", "c"} + is.True(narr.In(arr2, "b")) + is.False(narr.NotIn(arr2, "b")) +} + +func TestContainsAll(t *testing.T) { + is := assert.New(t) + + arr := []int{1, 2, 3} + is.True(narr.ContainsAll(arr, []int{2})) + is.False(narr.ContainsAll(arr, []int{2, 45})) + is.True(narr.IsParent(arr, []int{2})) + + arr2 := []string{"a", "b", "c"} + is.True(narr.ContainsAll(arr2, []string{"b"})) + is.False(narr.ContainsAll(arr2, []string{"b", "e"})) + is.True(narr.IsParent(arr2, []string{"b"})) +} + +func TestContains(t *testing.T) { + is := assert.New(t) + tests := map[any]any{ + 1: []int{1, 2, 3}, + 2: []int8{1, 2, 3}, + 3: []int16{1, 2, 3}, + 4: []int32{4, 2, 3}, + 5: []int64{5, 2, 3}, + 6: []uint{6, 2, 3}, + 7: []uint8{7, 2, 3}, + 8: []uint16{8, 2, 3}, + 9: []uint32{9, 2, 3}, + 10: []uint64{10, 3}, + 11: []string{"11", "3"}, + 'a': []int64{97}, + 'b': []rune{'a', 'b'}, + 'c': []byte{'a', 'b', 'c'}, // byte -> uint8 + "a": []string{"a", "b", "c"}, + 12: [5]uint{12, 1, 2, 3, 4}, + 'A': [3]rune{'A', 'B', 'C'}, + 'd': [4]byte{'a', 'b', 'c', 'd'}, + "aa": [3]string{"aa", "bb", "cc"}, + } + + for val, list := range tests { + is.True(narr.Contains(list, val)) + is.False(narr.NotContains(list, val)) + } + + is.False(narr.Contains(nil, []int{})) + is.False(narr.Contains('a', []int{})) + // + is.False(narr.Contains([]int{2, 3}, []int{2})) + is.False(narr.Contains([]int{2, 3}, "a")) + is.False(narr.Contains([]string{"a", "b"}, 12)) + is.False(narr.Contains(nil, 12)) + is.False(narr.Contains(map[int]int{2: 3}, 12)) + + tests1 := map[any]any{ + 2: []int{1, 3}, + "a": []string{"b", "c"}, + } + + for val, list := range tests1 { + is.True(narr.NotContains(list, val)) + is.False(narr.Contains(list, val)) + } +} diff --git a/narr/collection.go b/narr/collection.go new file mode 100644 index 0000000..00a8561 --- /dev/null +++ b/narr/collection.go @@ -0,0 +1,427 @@ +package narr + +import ( + "errors" + "reflect" +) + +// ErrElementNotFound is the error returned when the element is not found. +var ErrElementNotFound = errors.New("element not found") + +// Comparer Function to compare two elements. +type Comparer func(a, b any) int + +// Predicate Function to predicate a struct/value satisfies a condition. +type Predicate func(a any) bool + +var ( + // StringEqualsComparer Comparer for string. It will compare the string by their value. + // returns: 0 if equal, -1 if a != b + StringEqualsComparer Comparer = func(a, b any) int { + typeOfA := reflect.TypeOf(a) + if typeOfA.Kind() == reflect.Ptr { + typeOfA = typeOfA.Elem() + } + + typeOfB := reflect.TypeOf(b) + if typeOfB.Kind() == reflect.Ptr { + typeOfB = typeOfB.Elem() + } + + if typeOfA != typeOfB { + return -1 + } + + strA := "" + strB := "" + + if val, ok := a.(string); ok { + strA = val + } else if val, ok := a.(*string); ok { + strA = *val + } else { + return -1 + } + + if val, ok := b.(string); ok { + strB = val + } else if val, ok := b.(*string); ok { + strB = *val + } else { + return -1 + } + + if strA == strB { + return 0 + } + return -1 + } + + // ReferenceEqualsComparer Comparer for strcut ptr. It will compare the struct by their ptr addr. + // returns: 0 if equal, -1 if a != b + ReferenceEqualsComparer Comparer = func(a, b any) int { + if a == b { + return 0 + } + return -1 + } + + // ElemTypeEqualsComparer Comparer for struct/value. It will compare the struct by their element type (reflect.Type.Elem()). + // returns: 0 if same type, -1 if not. + ElemTypeEqualsComparer Comparer = func(a, b any) int { + at := reflect.TypeOf(a) + bt := reflect.TypeOf(b) + if at.Kind() == reflect.Ptr { + at = at.Elem() + } + + if bt.Kind() == reflect.Ptr { + bt = bt.Elem() + } + + if at == bt { + return 0 + } + return -1 + } +) + +// TwoWaySearch Find specialized element in a slice forward and backward in the same time, should be more quickly. +// +// data: the slice to search in. MUST BE A SLICE. +// item: the element to search. +// fn: the comparer function. +// return: the index of the element, or -1 if not found. +func TwoWaySearch(data any, item any, fn Comparer) (int, error) { + if data == nil { + return -1, errors.New("collections.TwowaySearch: data is nil") + } + if fn == nil { + return -1, errors.New("collections.TwowaySearch: fn is nil") + } + + dataType := reflect.TypeOf(data) + if dataType.Kind() != reflect.Slice { + return -1, errors.New("collections.TwowaySearch: data is not a slice") + } + + dataVal := reflect.ValueOf(data) + if dataVal.Len() == 0 { + return -1, errors.New("collections.TwowaySearch: data is empty") + } + itemType := dataType.Elem() + if itemType.Kind() == reflect.Ptr { + itemType = itemType.Elem() + } + + if itemType != dataVal.Index(0).Type() { + return -1, errors.New("collections.TwowaySearch: item type is not the same as data type") + } + + forward := 0 + backward := dataVal.Len() - 1 + + for forward <= backward { + forwardVal := dataVal.Index(forward).Interface() + if fn(forwardVal, item) == 0 { + return forward, nil + } + + backwardVal := dataVal.Index(backward).Interface() + if fn(backwardVal, item) == 0 { + return backward, nil + } + + forward++ + backward-- + } + + return -1, ErrElementNotFound +} + +// MakeEmptySlice Create a new slice with the elements of the source that satisfy the predicate. +// +// itemType: the type of the elements in the source. +// returns: the new slice. +func MakeEmptySlice(itemType reflect.Type) any { + ret := reflect.MakeSlice(reflect.SliceOf(itemType), 0, 0).Interface() + return ret +} + +// CloneSlice Clone a slice. +// +// data: the slice to clone. +// returns: the cloned slice. +func CloneSlice(data any) any { + typeOfData := reflect.TypeOf(data) + if typeOfData.Kind() != reflect.Slice { + panic("collections.CloneSlice: data must be a slice") + } + return reflect.AppendSlice(reflect.New(reflect.SliceOf(typeOfData.Elem())).Elem(), reflect.ValueOf(data)).Interface() +} + +// Differences Produces the set difference of two slice according to a comparer function. +// +// first: the first slice. MUST BE A SLICE. +// second: the second slice. MUST BE A SLICE. +// fn: the comparer function. +// returns: the difference of the two slices. +func Differences[T any](first, second []T, fn Comparer) []T { + typeOfFirst := reflect.TypeOf(first) + if typeOfFirst.Kind() != reflect.Slice { + panic("collections.Excepts: first must be a slice") + } + + typeOfSecond := reflect.TypeOf(second) + if typeOfSecond.Kind() != reflect.Slice { + panic("collections.Excepts: second must be a slice") + } + + firstLen := len(first) + if firstLen == 0 { + return CloneSlice(second).([]T) + } + + secondLen := len(second) + if secondLen == 0 { + return CloneSlice(first).([]T) + } + + max := firstLen + if secondLen > firstLen { + max = secondLen + } + + result := make([]T, 0) + for i := 0; i < max; i++ { + if i < firstLen { + s := first[i] + if i, _ := TwoWaySearch(second, s, fn); i < 0 { + result = append(result, s) + } + } + + if i < secondLen { + t := second[i] + if i, _ := TwoWaySearch(first, t, fn); i < 0 { + result = append(result, t) + } + } + } + + return result +} + +// Excepts Produces the set difference of two slice according to a comparer function. +// +// first: the first slice. MUST BE A SLICE. +// second: the second slice. MUST BE A SLICE. +// fn: the comparer function. +// returns: the difference of the two slices. +func Excepts(first, second any, fn Comparer) any { + typeOfFirst := reflect.TypeOf(first) + if typeOfFirst.Kind() != reflect.Slice { + panic("collections.Excepts: first must be a slice") + } + valOfFirst := reflect.ValueOf(first) + if valOfFirst.Len() == 0 { + return MakeEmptySlice(typeOfFirst.Elem()) + } + + typeOfSecond := reflect.TypeOf(second) + if typeOfSecond.Kind() != reflect.Slice { + panic("collections.Excepts: second must be a slice") + } + + valOfSecond := reflect.ValueOf(second) + if valOfSecond.Len() == 0 { + return CloneSlice(first) + } + + result := reflect.New(reflect.SliceOf(typeOfFirst.Elem())).Elem() + for i := 0; i < valOfFirst.Len(); i++ { + s := valOfFirst.Index(i).Interface() + if i, _ := TwoWaySearch(second, s, fn); i < 0 { + result = reflect.Append(result, reflect.ValueOf(s)) + } + } + + return result.Interface() +} + +// Intersects Produces to intersect of two slice according to a comparer function. +// +// first: the first slice. MUST BE A SLICE. +// second: the second slice. MUST BE A SLICE. +// fn: the comparer function. +// returns: to intersect of the two slices. +func Intersects(first any, second any, fn Comparer) any { + typeOfFirst := reflect.TypeOf(first) + if typeOfFirst.Kind() != reflect.Slice { + panic("collections.Intersects: first must be a slice") + } + valOfFirst := reflect.ValueOf(first) + if valOfFirst.Len() == 0 { + return MakeEmptySlice(typeOfFirst.Elem()) + } + + typeOfSecond := reflect.TypeOf(second) + if typeOfSecond.Kind() != reflect.Slice { + panic("collections.Intersects: second must be a slice") + } + + valOfSecond := reflect.ValueOf(second) + if valOfSecond.Len() == 0 { + return MakeEmptySlice(typeOfFirst.Elem()) + } + + result := reflect.New(reflect.SliceOf(typeOfFirst.Elem())).Elem() + for i := 0; i < valOfFirst.Len(); i++ { + s := valOfFirst.Index(i).Interface() + if i, _ := TwoWaySearch(second, s, fn); i >= 0 { + result = reflect.Append(result, reflect.ValueOf(s)) + } + } + + return result.Interface() +} + +// Union Produces the set union of two slice according to a comparer function +// +// first: the first slice. MUST BE A SLICE. +// second: the second slice. MUST BE A SLICE. +// fn: the comparer function. +// returns: the union of the two slices. +func Union(first, second any, fn Comparer) any { + excepts := Excepts(second, first, fn) + + typeOfFirst := reflect.TypeOf(first) + if typeOfFirst.Kind() != reflect.Slice { + panic("collections.Intersects: first must be a slice") + } + valOfFirst := reflect.ValueOf(first) + if valOfFirst.Len() == 0 { + return CloneSlice(second) + } + + result := reflect.AppendSlice(reflect.New(reflect.SliceOf(typeOfFirst.Elem())).Elem(), valOfFirst) + result = reflect.AppendSlice(result, reflect.ValueOf(excepts)) + return result.Interface() +} + +// Find Produces the struct/value of a slice according to a predicate function. +// +// source: the slice. MUST BE A SLICE. +// fn: the predicate function. +// returns: the struct/value of the slice. +func Find(source any, fn Predicate) (any, error) { + aType := reflect.TypeOf(source) + if aType.Kind() != reflect.Slice { + panic("collections.Find: source must be a slice") + } + + sourceVal := reflect.ValueOf(source) + if sourceVal.Len() == 0 { + return nil, ErrElementNotFound + } + + for i := 0; i < sourceVal.Len(); i++ { + s := sourceVal.Index(i).Interface() + if fn(s) { + return s, nil + } + } + return nil, ErrElementNotFound +} + +// FindOrDefault Produce the struct/value f a slice to a predicate function, +// Produce default value when predicate function not found. +// +// source: the slice. MUST BE A SLICE. +// fn: the predicate function. +// defaultValue: the default value. +// returns: the struct/value of the slice. +func FindOrDefault(source any, fn Predicate, defaultValue any) any { + item, err := Find(source, fn) + if err != nil { + if errors.Is(err, ErrElementNotFound) { + return defaultValue + } + } + return item +} + +// TakeWhile Produce the set of a slice according to a predicate function, +// Produce empty slice when predicate function not matched. +// +// data: the slice. MUST BE A SLICE. +// fn: the predicate function. +// returns: the set of the slice. +func TakeWhile(data any, fn Predicate) any { + aType := reflect.TypeOf(data) + if aType.Kind() != reflect.Slice { + panic("collections.TakeWhile: data must be a slice") + } + + sourceVal := reflect.ValueOf(data) + if sourceVal.Len() == 0 { + return MakeEmptySlice(aType.Elem()) + } + + result := reflect.New(reflect.SliceOf(aType.Elem())).Elem() + for i := 0; i < sourceVal.Len(); i++ { + s := sourceVal.Index(i).Interface() + if fn(s) { + result = reflect.Append(result, reflect.ValueOf(s)) + } + } + return result.Interface() +} + +// ExceptWhile Produce the set of a slice except with a predicate function, +// Produce original slice when predicate function not match. +// +// data: the slice. MUST BE A SLICE. +// fn: the predicate function. +// returns: the set of the slice. +func ExceptWhile(data any, fn Predicate) any { + aType := reflect.TypeOf(data) + if aType.Kind() != reflect.Slice { + panic("collections.ExceptWhile: data must be a slice") + } + + sourceVal := reflect.ValueOf(data) + if sourceVal.Len() == 0 { + return MakeEmptySlice(aType.Elem()) + } + + result := reflect.New(reflect.SliceOf(aType.Elem())).Elem() + for i := 0; i < sourceVal.Len(); i++ { + s := sourceVal.Index(i).Interface() + if !fn(s) { + result = reflect.Append(result, reflect.ValueOf(s)) + } + } + return result.Interface() +} + +// type MapFn func(obj T) (target V, find bool) + +// Map a list to new list +// +// eg: mapping [object0{},object1{},...] to flatten list [object0.someKey, object1.someKey, ...] +func Map[T any, V any](list []T, mapFn func(obj T) (val V, find bool)) []V { + flatArr := make([]V, 0, len(list)) + + for _, obj := range list { + if target, ok := mapFn(obj); ok { + flatArr = append(flatArr, target) + } + } + return flatArr +} + +// Column alias of Map func +func Column[T any, V any](list []T, mapFn func(obj T) (val V, find bool)) []V { + return Map(list, mapFn) +} diff --git a/narr/collection_test.go b/narr/collection_test.go new file mode 100644 index 0000000..f8b3e16 --- /dev/null +++ b/narr/collection_test.go @@ -0,0 +1,317 @@ +package narr_test + +import ( + "git.noahlan.cn/noahlan/ntool/narr" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +// StringEqualComparer tests +func TestStringEqualComparerShouldEquals(t *testing.T) { + assert.Eq(t, 0, narr.StringEqualsComparer("a", "a")) +} + +func TestStringEqualComparerShouldNotEquals(t *testing.T) { + assert.NotEq(t, 0, narr.StringEqualsComparer("a", "b")) +} + +func TestStringEqualComparerElementNotString(t *testing.T) { + assert.Eq(t, -1, narr.StringEqualsComparer(1, "a")) +} + +func TestStringEqualComparerPtr(t *testing.T) { + ptrVal := "a" + assert.Eq(t, 0, narr.StringEqualsComparer(&ptrVal, "a")) +} + +// ReferenceEqualsComparer tests +func TestReferenceEqualsComparerShouldEquals(t *testing.T) { + assert.Eq(t, 0, narr.ReferenceEqualsComparer(1, 1)) +} + +func TestReferenceEqualsComparerShouldNotEquals(t *testing.T) { + assert.NotEq(t, 0, narr.ReferenceEqualsComparer(1, 2)) +} + +// ElemTypeEqualCompareFunc +func TestElemTypeEqualCompareFuncShouldEquals(t *testing.T) { + assert.Eq(t, 0, narr.ElemTypeEqualsComparer(1, 2)) +} + +func TestElemTypeEqualCompareFuncShouldNotEquals(t *testing.T) { + assert.NotEq(t, 0, narr.ElemTypeEqualsComparer(1, "2")) +} + +func TestDifferencesShouldPassed(t *testing.T) { + data := []string{ + "a", + "b", + "c", + } + result := narr.Differences[string](data, []string{"a", "b"}, narr.StringEqualsComparer) + assert.Eq(t, []string{"c"}, result) + result = narr.Differences[string]([]string{"a", "b"}, data, narr.StringEqualsComparer) + assert.Eq(t, []string{"c"}, result) + result = narr.Differences[string]([]string{"a", "b", "d"}, data, narr.StringEqualsComparer) + assert.Eq(t, 2, len(result)) +} + +func TestExceptsShouldPassed(t *testing.T) { + data := []string{ + "a", + "b", + "c", + } + result := narr.Excepts(data, []string{"a", "b"}, narr.StringEqualsComparer) + assert.Eq(t, []string{"c"}, result.([]string)) +} + +func TestExceptsFirstNotSliceShouldPanic(t *testing.T) { + defer func() { + if r := recover(); r != nil { + return + } else { + t.Fail() + } + }() + narr.Excepts([1]string{"a"}, []string{"a", "b"}, narr.StringEqualsComparer) +} + +func TestExceptsSecondNotSliceShouldPanic(t *testing.T) { + defer func() { + if r := recover(); r != nil { + return + } else { + t.Fail() + } + }() + narr.Excepts([]string{"a", "b"}, [1]string{"a"}, narr.StringEqualsComparer) +} + +func TestExceptsFirstEmptyShouldReturnsEmpty(t *testing.T) { + data := []string{} + result := narr.Excepts(data, []string{"a", "b"}, narr.StringEqualsComparer).([]string) + assert.Eq(t, []string{}, result) + assert.NotSame(t, &data, &result, "should always returns new slice") +} + +func TestExceptsSecondEmptyShouldReturnsFirst(t *testing.T) { + data := []string{"a", "b"} + result := narr.Excepts(data, []string{}, narr.StringEqualsComparer).([]string) + assert.Eq(t, data, result) + assert.NotSame(t, &data, &result, "should always returns new slice") +} + +// Intersects tests +func TestIntersectsShouldPassed(t *testing.T) { + data := []string{ + "a", + "b", + "c", + } + result := narr.Intersects(data, []string{"a", "b"}, narr.StringEqualsComparer) + assert.Eq(t, []string{"a", "b"}, result.([]string)) +} + +func TestIntersectsFirstNotSliceShouldPanic(t *testing.T) { + defer func() { + if r := recover(); r != nil { + return + } else { + t.Fail() + } + }() + narr.Intersects([1]string{"a"}, []string{"a", "b"}, narr.StringEqualsComparer) +} + +func TestIntersectsSecondNotSliceShouldPanic(t *testing.T) { + defer func() { + if r := recover(); r != nil { + return + } else { + t.Fail() + } + }() + narr.Intersects([]string{"a", "b"}, [1]string{"a"}, narr.StringEqualsComparer) +} + +func TestIntersectsFirstEmptyShouldReturnsEmpty(t *testing.T) { + data := []string{} + second := []string{"a", "b"} + result := narr.Intersects(data, second, narr.StringEqualsComparer).([]string) + assert.Eq(t, []string{}, result) + assert.NotSame(t, &second, &result, "should always returns new slice") +} + +func TestIntersectsSecondEmptyShouldReturnsEmpty(t *testing.T) { + data := []string{"a", "b"} + second := []string{} + result := narr.Intersects(data, second, narr.StringEqualsComparer).([]string) + assert.Eq(t, []string{}, result) + assert.NotSame(t, &data, &result, "should always returns new slice") +} + +// Union tests + +func TestUnionShouldPassed(t *testing.T) { + data := []string{ + "a", + "b", + "c", + } + result := narr.Union(data, []string{"a", "b", "d"}, narr.StringEqualsComparer).([]string) + assert.Eq(t, []string{"a", "b", "c", "d"}, result) +} + +func TestUnionFirstNotSliceShouldPanic(t *testing.T) { + defer func() { + if r := recover(); r != nil { + return + } else { + t.Fail() + } + }() + narr.Union([1]string{"a"}, []string{"a", "b"}, narr.StringEqualsComparer) +} + +func TestUnionSecondNotSliceShouldPanic(t *testing.T) { + defer func() { + if r := recover(); r != nil { + return + } else { + t.Fail() + } + }() + + narr.Union([]string{"a", "b"}, [1]string{"a"}, narr.StringEqualsComparer) +} + +func TestUnionFirstEmptyShouldReturnsSecond(t *testing.T) { + data := []string{} + second := []string{"a", "b"} + result := narr.Union(data, second, narr.StringEqualsComparer).([]string) + assert.Eq(t, []string{"a", "b"}, result) + assert.NotSame(t, &second, &result, "should always returns new slice") +} + +func TestUnionSecondEmptyShouldReturnsFirst(t *testing.T) { + data := []string{"a", "b"} + second := []string{} + result := narr.Union(data, second, narr.StringEqualsComparer).([]string) + assert.Eq(t, data, result) + assert.NotSame(t, &data, &result, "should always returns new slice") +} + +// Find tests +func TestFindShouldPassed(t *testing.T) { + data := []string{ + "a", + "b", + "c", + } + + result, err := narr.Find(data, func(a any) bool { return a == "b" }) + assert.Nil(t, err) + assert.Eq(t, "b", result) + + _, err = narr.Find(data, func(a any) bool { return a == "d" }) + assert.NotNil(t, err) + assert.Eq(t, narr.ErrElementNotFound, err) + +} + +func TestFindNotSliceShouldPanic(t *testing.T) { + assert.Panics(t, func() { + _, _ = narr.Find([1]string{"a"}, func(a any) bool { return a == "b" }) + }) +} + +func TestFindEmptyReturnsErrElementNotFound(t *testing.T) { + data := []string{} + _, err := narr.Find(data, func(a any) bool { return a == "b" }) + assert.NotNil(t, err) + assert.Eq(t, narr.ErrElementNotFound, err) +} + +// FindOrDefault tests +func TestFindOrDefaultShouldPassed(t *testing.T) { + data := []string{ + "a", + "b", + "c", + } + + result := narr.FindOrDefault(data, func(a any) bool { return a == "b" }, "d").(string) + assert.Eq(t, "b", result) + + result = narr.FindOrDefault(data, func(a any) bool { return a == "d" }, "d").(string) + assert.Eq(t, "d", result) +} + +// TakeWhile tests +func TestTakeWhileShouldPassed(t *testing.T) { + data := []string{ + "a", + "b", + "c", + } + + result := narr.TakeWhile(data, func(a any) bool { return a == "b" || a == "c" }).([]string) + assert.Eq(t, []string{"b", "c"}, result) +} + +func TestTakeWhileNotSliceShouldPanic(t *testing.T) { + assert.Panics(t, func() { + narr.TakeWhile([1]string{"a"}, func(a any) bool { return a == "b" || a == "c" }) + }) +} + +func TestTakeWhileEmptyReturnsEmpty(t *testing.T) { + var data []string + result := narr.TakeWhile(data, func(a any) bool { return a == "b" || a == "c" }).([]string) + assert.Eq(t, []string{}, result) + assert.NotSame(t, &data, &result, "should always returns new slice") +} + +// ExceptWhile tests + +func TestExceptWhileShouldPassed(t *testing.T) { + data := []string{ + "a", + "b", + "c", + } + + result := narr.ExceptWhile(data, func(a any) bool { return a == "b" || a == "c" }).([]string) + assert.Eq(t, []string{"a"}, result) +} + +func TestExceptWhileNotSliceShouldPanic(t *testing.T) { + assert.Panics(t, func() { + narr.ExceptWhile([1]string{"a"}, func(a any) bool { return a == "b" || a == "c" }) + }) +} + +func TestExceptWhileEmptyReturnsEmpty(t *testing.T) { + var data []string + result := narr.ExceptWhile(data, func(a any) bool { return a == "b" || a == "c" }).([]string) + + assert.Eq(t, []string{}, result) + assert.NotSame(t, &data, &result, "should always returns new slice") +} + +func TestMap(t *testing.T) { + list1 := []map[string]any{ + {"name": "tom", "age": 23}, + {"name": "john", "age": 34}, + } + + flatArr := narr.Column(list1, func(obj map[string]any) (val any, find bool) { + return obj["age"], true + }) + + assert.NotEmpty(t, flatArr) + assert.Contains(t, flatArr, 23) + assert.Len(t, flatArr, 2) + assert.Eq(t, 34, flatArr[1]) +} diff --git a/narr/convert.go b/narr/convert.go new file mode 100644 index 0000000..3acf90f --- /dev/null +++ b/narr/convert.go @@ -0,0 +1,266 @@ +package narr + +import ( + "errors" + "git.noahlan.cn/noahlan/ntool/ndef" + "git.noahlan.cn/noahlan/ntool/nmath" + "git.noahlan.cn/noahlan/ntool/nreflect" + "git.noahlan.cn/noahlan/ntool/nstr" + "reflect" + "strconv" + "strings" +) + +// ErrInvalidType error +var ErrInvalidType = errors.New("the input param type is invalid") + +/************************************************************* + * Join func for slice + *************************************************************/ + +// JoinStrings alias of strings.Join +func JoinStrings(sep string, ss ...string) string { + return strings.Join(ss, sep) +} + +// JoinSlice join []any slice to string. +func JoinSlice(sep string, arr ...any) string { + if arr == nil { + return "" + } + + var sb strings.Builder + for i, v := range arr { + if i > 0 { + sb.WriteString(sep) + } + sb.WriteString(nstr.SafeString(v)) + } + + return sb.String() +} + +/************************************************************* + * helper func for slices + *************************************************************/ + +// ToInt64s convert any(allow: array,slice) to []int64 +func ToInt64s(arr any) (ret []int64, err error) { + rv := reflect.ValueOf(arr) + if rv.Kind() != reflect.Slice && rv.Kind() != reflect.Array { + err = ErrInvalidType + return + } + + for i := 0; i < rv.Len(); i++ { + i64, err := nmath.Int64(rv.Index(i).Interface()) + if err != nil { + return []int64{}, err + } + + ret = append(ret, i64) + } + return +} + +// MustToInt64s convert any(allow: array,slice) to []int64 +func MustToInt64s(arr any) []int64 { + ret, _ := ToInt64s(arr) + return ret +} + +// SliceToInt64s convert []any to []int64 +func SliceToInt64s(arr []any) []int64 { + i64s := make([]int64, len(arr)) + for i, v := range arr { + i64s[i] = nmath.QuietInt64(v) + } + return i64s +} + +// StringsAsInts convert and ignore error +func StringsAsInts(ss []string) []int { + ints, _ := StringsTryInts(ss) + return ints +} + +// StringsToInts string slice to int slice +func StringsToInts(ss []string) (ints []int, err error) { + return StringsTryInts(ss) +} + +// StringsTryInts string slice to int slice +func StringsTryInts(ss []string) (ints []int, err error) { + for _, str := range ss { + iVal, err := strconv.Atoi(str) + if err != nil { + return nil, err + } + + ints = append(ints, iVal) + } + return +} + +// AnyToSlice convert any(allow: array,slice) to []any +func AnyToSlice(sl any) (ls []any, err error) { + rfKeys := reflect.ValueOf(sl) + if rfKeys.Kind() != reflect.Slice && rfKeys.Kind() != reflect.Array { + return nil, ErrInvalidType + } + + for i := 0; i < rfKeys.Len(); i++ { + ls = append(ls, rfKeys.Index(i).Interface()) + } + return +} + +// AnyToStrings convert array or slice to []string +func AnyToStrings(arr any) []string { + ret, _ := ToStrings(arr) + return ret +} + +// MustToStrings convert array or slice to []string +func MustToStrings(arr any) []string { + ret, err := ToStrings(arr) + if err != nil { + panic(err) + } + return ret +} + +// StringsToSlice convert []string to []any +func StringsToSlice(ss []string) []any { + args := make([]any, len(ss)) + for i, s := range ss { + args[i] = s + } + return args +} + +// ToStrings convert any(allow: array,slice) to []string +func ToStrings(arr any) (ret []string, err error) { + rv := reflect.ValueOf(arr) + if rv.Kind() == reflect.String { + return []string{rv.String()}, nil + } + + if rv.Kind() != reflect.Slice && rv.Kind() != reflect.Array { + err = ErrInvalidType + return + } + + for i := 0; i < rv.Len(); i++ { + str, err := nstr.ToString(rv.Index(i).Interface()) + if err != nil { + return []string{}, err + } + + ret = append(ret, str) + } + return +} + +// SliceToStrings convert []any to []string +func SliceToStrings(arr []any) []string { + return QuietStrings(arr) +} + +// QuietStrings convert []any to []string +func QuietStrings(arr []any) []string { + ss := make([]string, len(arr)) + for i, v := range arr { + ss[i] = nstr.SafeString(v) + } + return ss +} + +// ConvType convert type of slice elements to new type slice, by the given newElemTyp type. +// +// Supports conversion between []string, []intX, []uintX, []floatX. +// +// Usage: +// +// ints, _ := narr.ConvType([]string{"12", "23"}, 1) // []int{12, 23} +func ConvType[T any, R any](arr []T, newElemTyp R) ([]R, error) { + newArr := make([]R, len(arr)) + elemTyp := reflect.TypeOf(newElemTyp) + + for i, elem := range arr { + var anyElem any = elem + // type is same. + if _, ok := anyElem.(R); ok { + newArr[i] = anyElem.(R) + continue + } + + // need conv type. + rfVal, err := nreflect.ValueByType(elem, elemTyp) + if err != nil { + return nil, err + } + newArr[i] = rfVal.Interface().(R) + } + return newArr, nil +} + +// AnyToString simple and quickly convert any array, slice to string +func AnyToString(arr any) string { + return NewFormatter(arr).Format() +} + +// SliceToString convert []any to string +func SliceToString(arr ...any) string { return ToString(arr) } + +// ToString simple and quickly convert []any to string +func ToString(arr []any) string { + // like fmt.Println([]any(nil)) + if arr == nil { + return "[]" + } + + var sb strings.Builder + sb.WriteByte('[') + + for i, v := range arr { + if i > 0 { + sb.WriteByte(',') + } + sb.WriteString(nstr.SafeString(v)) + } + + sb.WriteByte(']') + return sb.String() +} + +// CombineToMap combine two slice to map[K]V. +// +// If keys length is greater than values, the extra keys will be ignored. +func CombineToMap[K ndef.SortedType, V any](keys []K, values []V) map[K]V { + ln := len(values) + mp := make(map[K]V, len(keys)) + + for i, key := range keys { + if i >= ln { + break + } + mp[key] = values[i] + } + return mp +} + +// CombineToSMap combine two string-slice to map[string]string +func CombineToSMap(keys, values []string) map[string]string { + ln := len(values) + mp := make(map[string]string, len(keys)) + + for i, key := range keys { + if ln > i { + mp[key] = values[i] + } else { + mp[key] = "" + } + } + return mp +} diff --git a/narr/convert_test.go b/narr/convert_test.go new file mode 100644 index 0000000..4c11591 --- /dev/null +++ b/narr/convert_test.go @@ -0,0 +1,152 @@ +package narr_test + +import ( + "fmt" + "git.noahlan.cn/noahlan/ntool/narr" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestToInt64s(t *testing.T) { + is := assert.New(t) + + ints, err := narr.ToInt64s([]string{"1", "2"}) + is.Nil(err) + is.Eq("[]int64{1, 2}", fmt.Sprintf("%#v", ints)) + + ints = narr.MustToInt64s([]string{"1", "2"}) + is.Eq("[]int64{1, 2}", fmt.Sprintf("%#v", ints)) + + ints = narr.MustToInt64s([]any{"1", "2"}) + is.Eq("[]int64{1, 2}", fmt.Sprintf("%#v", ints)) + + ints = narr.SliceToInt64s([]any{"1", "2"}) + is.Eq("[]int64{1, 2}", fmt.Sprintf("%#v", ints)) + + _, err = narr.ToInt64s([]string{"a", "b"}) + is.Err(err) +} + +func TestToStrings(t *testing.T) { + is := assert.New(t) + + ss, err := narr.ToStrings([]int{1, 2}) + is.Nil(err) + is.Eq(`[]string{"1", "2"}`, fmt.Sprintf("%#v", ss)) + + ss = narr.MustToStrings([]int{1, 2}) + is.Eq(`[]string{"1", "2"}`, fmt.Sprintf("%#v", ss)) + + ss = narr.MustToStrings([]any{1, 2}) + is.Eq(`[]string{"1", "2"}`, fmt.Sprintf("%#v", ss)) + + ss = narr.SliceToStrings([]any{1, 2}) + is.Eq(`[]string{"1", "2"}`, fmt.Sprintf("%#v", ss)) + + as := narr.StringsToSlice([]string{"1", "2"}) + is.Eq(`[]interface {}{"1", "2"}`, fmt.Sprintf("%#v", as)) + + ss, err = narr.ToStrings("b") + is.Nil(err) + is.Eq(`[]string{"b"}`, fmt.Sprintf("%#v", ss)) + + _, err = narr.ToStrings([]any{[]int{1}, nil}) + is.Err(err) +} + +func TestStringsToString(t *testing.T) { + is := assert.New(t) + + is.Eq("a,b", narr.JoinStrings(",", []string{"a", "b"}...)) + is.Eq("a,b", narr.JoinStrings(",", []string{"a", "b"}...)) + is.Eq("a,b", narr.JoinStrings(",", "a", "b")) +} + +func TestAnyToString(t *testing.T) { + is := assert.New(t) + arr := [2]string{"a", "b"} + + is.Eq("", narr.AnyToString(nil)) + is.Eq("[]", narr.AnyToString([]string{})) + is.Eq("[a, b]", narr.AnyToString(arr)) + is.Eq("[a, b]", narr.AnyToString([]string{"a", "b"})) + is.Eq("", narr.AnyToString("invalid")) +} + +func TestSliceToString(t *testing.T) { + is := assert.New(t) + + is.Eq("[]", narr.SliceToString(nil)) + is.Eq("[a,b]", narr.SliceToString("a", "b")) +} + +func TestStringsToInts(t *testing.T) { + is := assert.New(t) + + ints, err := narr.StringsToInts([]string{"1", "2"}) + is.Nil(err) + is.Eq("[]int{1, 2}", fmt.Sprintf("%#v", ints)) + + _, err = narr.StringsToInts([]string{"a", "b"}) + is.Err(err) + + ints = narr.StringsAsInts([]string{"1", "2"}) + is.Eq("[]int{1, 2}", fmt.Sprintf("%#v", ints)) + is.Nil(narr.StringsAsInts([]string{"abc"})) +} + +func TestConvType(t *testing.T) { + is := assert.New(t) + + // []string => []int + arr, err := narr.ConvType([]string{"1", "2"}, 1) + is.Nil(err) + is.Eq("[]int{1, 2}", fmt.Sprintf("%#v", arr)) + + // []int => []string + arr1, err := narr.ConvType([]int{1, 2}, "1") + is.Nil(err) + is.Eq(`[]string{"1", "2"}`, fmt.Sprintf("%#v", arr1)) + + // not need conv + arr2, err := narr.ConvType([]string{"1", "2"}, "1") + is.Nil(err) + is.Eq(`[]string{"1", "2"}`, fmt.Sprintf("%#v", arr2)) + + // conv error + arr3, err := narr.ConvType([]string{"ab", "cd"}, 1) + is.Err(err) + is.Nil(arr3) +} + +func TestJoinSlice(t *testing.T) { + assert.Eq(t, "", narr.JoinSlice(",")) + assert.Eq(t, "", narr.JoinSlice(",", nil)) + assert.Eq(t, "a,23,b", narr.JoinSlice(",", "a", 23, "b")) +} + +func TestCombineToMap(t *testing.T) { + keys := []string{"key0", "key1"} + + mp := narr.CombineToMap(keys, []int{1, 2}) + assert.Len(t, mp, 2) + assert.Eq(t, 1, mp["key0"]) + assert.Eq(t, 2, mp["key1"]) + + mp = narr.CombineToMap(keys, []int{1}) + assert.Len(t, mp, 1) + assert.Eq(t, 1, mp["key0"]) +} + +func TestCombineToSMap(t *testing.T) { + keys := []string{"key0", "key1"} + + mp := narr.CombineToSMap(keys, []string{"val0", "val1"}) + assert.Len(t, mp, 2) + assert.Eq(t, "val0", mp["key0"]) + + mp = narr.CombineToSMap(keys, []string{"val0"}) + assert.Len(t, mp, 2) + assert.Eq(t, "val0", mp["key0"]) + assert.Eq(t, "", mp["key1"]) +} diff --git a/narr/format.go b/narr/format.go new file mode 100644 index 0000000..b0e4ef0 --- /dev/null +++ b/narr/format.go @@ -0,0 +1,124 @@ +package narr + +import ( + "git.noahlan.cn/noahlan/ntool/ndef" + "git.noahlan.cn/noahlan/ntool/nstr" + "io" + "reflect" +) + +// ArrFormatter struct +type ArrFormatter struct { + ndef.BaseFormatter + // Prefix string for each element + Prefix string + // Indent string for format each element + Indent string + // ClosePrefix string for last "]" + ClosePrefix string +} + +// NewFormatter instance +func NewFormatter(arr any) *ArrFormatter { + f := &ArrFormatter{} + f.Src = arr + + return f +} + +// WithFn for config self +func (f *ArrFormatter) WithFn(fn func(f *ArrFormatter)) *ArrFormatter { + fn(f) + return f +} + +// WithIndent string +func (f *ArrFormatter) WithIndent(indent string) *ArrFormatter { + f.Indent = indent + return f +} + +// FormatTo to custom buffer +func (f *ArrFormatter) FormatTo(w io.Writer) { + f.SetOutput(w) + f.doFormat() +} + +// Format to string +func (f *ArrFormatter) String() string { + f.Format() + return f.Format() +} + +// Format to string +func (f *ArrFormatter) Format() string { + f.doFormat() + return f.BsWriter().String() +} + +// Format to string +// +//goland:noinspection GoUnhandledErrorResult +func (f *ArrFormatter) doFormat() { + if f.Src == nil { + return + } + + rv, ok := f.Src.(reflect.Value) + if !ok { + rv = reflect.ValueOf(f.Src) + } + + rv = reflect.Indirect(rv) + if rv.Kind() != reflect.Slice && rv.Kind() != reflect.Array { + return + } + + writer := f.BsWriter() + arrLn := rv.Len() + if arrLn == 0 { + writer.WriteString("[]") + return + } + + // if f.AfterReset { + // defer f.Reset() + // } + + // sb.Grow(arrLn * 4) + writer.WriteByte('[') + + indentLn := len(f.Indent) + if indentLn > 0 { + writer.WriteByte('\n') + } + + for i := 0; i < arrLn; i++ { + if indentLn > 0 { + writer.WriteString(f.Indent) + } + writer.WriteString(nstr.SafeString(rv.Index(i).Interface())) + + if i < arrLn-1 { + writer.WriteByte(',') + + // no indent, with space + if indentLn == 0 { + writer.WriteByte(' ') + } + } + if indentLn > 0 { + writer.WriteByte('\n') + } + } + + if f.ClosePrefix != "" { + writer.WriteString(f.ClosePrefix) + } + writer.WriteByte(']') +} + +// FormatIndent array data to string. +func FormatIndent(arr any, indent string) string { + return NewFormatter(arr).WithIndent(indent).Format() +} diff --git a/narr/format_test.go b/narr/format_test.go new file mode 100644 index 0000000..ed89777 --- /dev/null +++ b/narr/format_test.go @@ -0,0 +1,23 @@ +package narr_test + +import ( + "fmt" + "git.noahlan.cn/noahlan/ntool/narr" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestNewFormatter(t *testing.T) { + arr := [2]string{"a", "b"} + str := narr.FormatIndent(arr, " ") + assert.Contains(t, str, "\n ") + fmt.Println(str) + + str = narr.FormatIndent(arr, "") + assert.NotContains(t, str, "\n ") + assert.Eq(t, "[a, b]", str) + fmt.Println(str) + + assert.Eq(t, "", narr.FormatIndent("invalid", "")) + assert.Eq(t, "[]", narr.FormatIndent([]string{}, "")) +} diff --git a/narr/types.go b/narr/types.go new file mode 100644 index 0000000..365ad4f --- /dev/null +++ b/narr/types.go @@ -0,0 +1,64 @@ +package narr + +import ( + "strconv" + "strings" +) + +// Ints type +type Ints []int + +// String to string +func (is Ints) String() string { + ss := make([]string, len(is)) + for i, iv := range is { + ss[i] = strconv.Itoa(iv) + } + return strings.Join(ss, ",") +} + +// Has given element +func (is Ints) Has(i int) bool { + for _, iv := range is { + if i == iv { + return true + } + } + return false +} + +// Strings type +type Strings []string + +// String to string +func (ss Strings) String() string { + return strings.Join(ss, ",") +} + +// Join to string +func (ss Strings) Join(sep string) string { + return strings.Join(ss, sep) +} + +// Has given element +func (ss Strings) Has(sub string) bool { + return ss.Contains(sub) +} + +// Contains given element +func (ss Strings) Contains(sub string) bool { + for _, s := range ss { + if s == sub { + return true + } + } + return false +} + +// First element value. +func (ss Strings) First() string { + if len(ss) > 0 { + return ss[0] + } + return "" +} diff --git a/narr/types_test.go b/narr/types_test.go new file mode 100644 index 0000000..fd132c2 --- /dev/null +++ b/narr/types_test.go @@ -0,0 +1,54 @@ +package narr_test + +import ( + "git.noahlan.cn/noahlan/ntool/narr" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestInts_Has_String(t *testing.T) { + tests := []struct { + is narr.Ints + val int + want bool + want2 string + }{ + { + narr.Ints{12, 23}, + 12, + true, + "12,23", + }, + } + + for _, tt := range tests { + assert.Eq(t, tt.want, tt.is.Has(tt.val)) + assert.False(t, tt.is.Has(999)) + assert.Eq(t, tt.want2, tt.is.String()) + } +} + +func TestStrings_methods(t *testing.T) { + tests := []struct { + ss narr.Strings + val string + want bool + want2 string + }{ + { + narr.Strings{"a", "b"}, + "a", + true, + "a,b", + }, + } + + for _, tt := range tests { + assert.Eq(t, tt.want, tt.ss.Has(tt.val)) + assert.False(t, tt.ss.Has("not-exists")) + assert.Eq(t, tt.want2, tt.ss.String()) + } + + ss := narr.Strings{"a", "b"} + assert.Eq(t, "a b", ss.Join(" ")) +} diff --git a/narr/util.go b/narr/util.go new file mode 100644 index 0000000..8f86680 --- /dev/null +++ b/narr/util.go @@ -0,0 +1,130 @@ +package narr + +import ( + "git.noahlan.cn/noahlan/ntool/ndef" + "git.noahlan.cn/noahlan/ntool/nrandom" + "strings" +) + +// Reverse string slice [site user info 0] -> [0 info user site] +func Reverse(ss []string) { + ln := len(ss) + for i := 0; i < ln/2; i++ { + li := ln - i - 1 + ss[i], ss[li] = ss[li], ss[i] + } +} + +// StringsRemove a value form a string slice +func StringsRemove(ss []string, s string) []string { + ns := make([]string, 0, len(ss)) + for _, v := range ss { + if v != s { + ns = append(ns, v) + } + } + return ns +} + +// StringsFilter given strings, default will filter emtpy string. +// +// Usage: +// +// // output: [a, b] +// ss := narr.StringsFilter([]string{"a", "", "b", ""}) +func StringsFilter(ss []string, filter ...func(s string) bool) []string { + var fn func(s string) bool + if len(filter) > 0 && filter[0] != nil { + fn = filter[0] + } else { + fn = func(s string) bool { + return s != "" + } + } + + ns := make([]string, 0, len(ss)) + for _, s := range ss { + if fn(s) { + ns = append(ns, s) + } + } + return ns +} + +// StringsMap handle each string item, map to new strings +func StringsMap(ss []string, mapFn func(s string) string) []string { + ns := make([]string, 0, len(ss)) + for _, s := range ss { + ns = append(ns, mapFn(s)) + } + return ns +} + +// TrimStrings trim string slice item. +// +// Usage: +// +// // output: [a, b, c] +// ss := narr.TrimStrings([]string{",a", "b.", ",.c,"}, ",.") +func TrimStrings(ss []string, cutSet ...string) []string { + cutSetLn := len(cutSet) + hasCutSet := cutSetLn > 0 && cutSet[0] != "" + + var trimSet string + if hasCutSet { + trimSet = cutSet[0] + } + if cutSetLn > 1 { + trimSet = strings.Join(cutSet, "") + } + + ns := make([]string, 0, len(ss)) + for _, str := range ss { + if hasCutSet { + ns = append(ns, strings.Trim(str, trimSet)) + } else { + ns = append(ns, strings.TrimSpace(str)) + } + } + return ns +} + +// GetRandomOne get random element from an array/slice +func GetRandomOne[T any](arr []T) T { return RandomOne(arr) } + +// RandomOne get random element from an array/slice +func RandomOne[T any](arr []T) T { + if ln := len(arr); ln > 0 { + i := nrandom.RandInt(0, len(arr)) + return arr[i] + } + panic("cannot get value from nil or empty slice") +} + +// Unique value in the given slice data. +func Unique[T ~string | ndef.XIntOrFloat](list []T) []T { + if len(list) < 2 { + return list + } + + valMap := make(map[T]struct{}, len(list)) + uniArr := make([]T, 0, len(list)) + + for _, t := range list { + if _, ok := valMap[t]; !ok { + valMap[t] = struct{}{} + uniArr = append(uniArr, t) + } + } + return uniArr +} + +// IndexOf value in given slice. +func IndexOf[T ~string | ndef.XIntOrFloat](val T, list []T) int { + for i, v := range list { + if v == val { + return i + } + } + return -1 +} diff --git a/narr/util_test.go b/narr/util_test.go new file mode 100644 index 0000000..5d3d68e --- /dev/null +++ b/narr/util_test.go @@ -0,0 +1,128 @@ +package narr_test + +import ( + "fmt" + "git.noahlan.cn/noahlan/ntool/narr" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestReverse(t *testing.T) { + ss := []string{"a", "b", "c"} + + narr.Reverse(ss) + assert.Eq(t, []string{"c", "b", "a"}, ss) +} + +func TestStringsRemove(t *testing.T) { + ss := []string{"a", "b", "c"} + ns := narr.StringsRemove(ss, "b") + + assert.Contains(t, ns, "a") + assert.NotContains(t, ns, "b") + assert.Len(t, ns, 2) +} + +func TestStringsFilter(t *testing.T) { + is := assert.New(t) + + ss := narr.StringsFilter([]string{"a", "", "b", ""}) + is.Eq([]string{"a", "b"}, ss) +} + +func TestTrimStrings(t *testing.T) { + is := assert.New(t) + + // TrimStrings + ss := narr.TrimStrings([]string{" a", "b ", " c "}) + is.Eq("[a b c]", fmt.Sprint(ss)) + ss = narr.TrimStrings([]string{",a", "b.", ",.c,"}, ",.") + is.Eq("[a b c]", fmt.Sprint(ss)) + ss = narr.TrimStrings([]string{",a", "b.", ",.c,"}, ",", ".") + is.Eq("[a b c]", fmt.Sprint(ss)) +} + +func TestGetRandomOne(t *testing.T) { + is := assert.New(t) + // int slice + intSlice := []int{1, 2, 3, 4, 5, 6} + intVal := narr.GetRandomOne(intSlice) + intVal1 := narr.GetRandomOne(intSlice) + for intVal == intVal1 { + intVal1 = narr.GetRandomOne(intSlice) + } + + assert.IsType(t, 0, intVal) + is.True(narr.HasValue(intSlice, intVal)) + assert.IsType(t, 0, intVal1) + is.True(narr.HasValue(intSlice, intVal1)) + assert.NotEq(t, intVal, intVal1) + + // int array + intArray := []int{1, 2, 3, 4, 5, 6} + intReturned := narr.GetRandomOne(intArray) + intReturned1 := narr.GetRandomOne(intArray) + for intReturned == intReturned1 { + intReturned1 = narr.GetRandomOne(intArray) + } + assert.IsType(t, 0, intReturned) + is.True(narr.Contains(intArray, intReturned)) + assert.IsType(t, 0, intReturned1) + is.True(narr.Contains(intArray, intReturned1)) + assert.NotEq(t, intReturned, intReturned1) + + // string slice + strSlice := []string{"aa", "bb", "cc", "dd"} + strVal := narr.GetRandomOne(strSlice) + strVal1 := narr.GetRandomOne(strSlice) + for strVal == strVal1 { + strVal1 = narr.GetRandomOne(strSlice) + } + + assert.IsType(t, "", strVal) + is.True(narr.Contains(strSlice, strVal)) + assert.IsType(t, "", strVal1) + is.True(narr.Contains(strSlice, strVal1)) + assert.NotEq(t, strVal, strVal1) + + // string array + strArray := []string{"aa", "bb", "cc", "dd"} + strReturned := narr.GetRandomOne(strArray) + strReturned1 := narr.GetRandomOne(strArray) + for strReturned == strReturned1 { + strReturned1 = narr.GetRandomOne(strArray) + } + + assert.IsType(t, "", strReturned) + is.True(narr.Contains(strArray, strReturned)) + assert.IsType(t, "", strReturned1) + is.True(narr.Contains(strArray, strReturned1)) + assert.NotEq(t, strReturned, strReturned1) + + // byte slice + byteSlice := []byte("abcdefg") + byteVal := narr.GetRandomOne(byteSlice) + byteVal1 := narr.GetRandomOne(byteSlice) + for byteVal == byteVal1 { + byteVal1 = narr.GetRandomOne(byteSlice) + } + + assert.IsType(t, byte('a'), byteVal) + is.True(narr.Contains(byteSlice, byteVal)) + assert.IsType(t, byte('a'), byteVal1) + is.True(narr.Contains(byteSlice, byteVal1)) + assert.NotEq(t, byteVal, byteVal1) + + is.Panics(func() { + narr.RandomOne([]int{}) + }) +} + +func TestUnique(t *testing.T) { + assert.Eq(t, []int{2, 3, 4}, narr.Unique[int]([]int{2, 3, 2, 4})) + assert.Eq(t, []uint{2, 3, 4}, narr.Unique([]uint{2, 3, 2, 4})) + assert.Eq(t, []string{"ab", "bc", "cd"}, narr.Unique([]string{"ab", "bc", "ab", "cd"})) + + assert.Eq(t, 1, narr.IndexOf(3, []int{2, 3, 4})) + assert.Eq(t, -1, narr.IndexOf(5, []int{2, 3, 4})) +} diff --git a/nbyte/buffer.go b/nbyte/buffer.go new file mode 100644 index 0000000..20e67e3 --- /dev/null +++ b/nbyte/buffer.go @@ -0,0 +1,65 @@ +package nbyte + +import ( + "bytes" + "fmt" + "strings" +) + +// Buffer wrap and extends the bytes.Buffer +type Buffer struct { + bytes.Buffer +} + +// NewBuffer instance +func NewBuffer() *Buffer { + return &Buffer{} +} + +// WriteAny type value to buffer +func (b *Buffer) WriteAny(vs ...any) { + for _, v := range vs { + _, _ = b.Buffer.WriteString(fmt.Sprint(v)) + } +} + +// QuietWriteByte to buffer +func (b *Buffer) QuietWriteByte(c byte) { + _ = b.WriteByte(c) +} + +// QuietWritef write message to buffer +func (b *Buffer) QuietWritef(tpl string, vs ...any) { + _, _ = b.WriteString(fmt.Sprintf(tpl, vs...)) +} + +// Writeln write message to buffer with newline +func (b *Buffer) Writeln(ss ...string) { + b.QuietWriteln(ss...) +} + +// QuietWriteln write message to buffer with newline +func (b *Buffer) QuietWriteln(ss ...string) { + _, _ = b.WriteString(strings.Join(ss, "")) + _ = b.WriteByte('\n') +} + +// QuietWriteString to buffer +func (b *Buffer) QuietWriteString(ss ...string) { + _, _ = b.WriteString(strings.Join(ss, "")) +} + +// MustWriteString to buffer +func (b *Buffer) MustWriteString(ss ...string) { + _, err := b.WriteString(strings.Join(ss, "")) + if err != nil { + panic(err) + } +} + +// ResetAndGet buffer string. +func (b *Buffer) ResetAndGet() string { + s := b.String() + b.Reset() + return s +} diff --git a/nbyte/buffer_test.go b/nbyte/buffer_test.go new file mode 100644 index 0000000..e1b7307 --- /dev/null +++ b/nbyte/buffer_test.go @@ -0,0 +1,25 @@ +package nbyte_test + +import ( + "git.noahlan.cn/noahlan/ntool/nbyte" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestBuffer_WriteAny(t *testing.T) { + buf := nbyte.NewBuffer() + + buf.QuietWritef("ab-%s", "c") + buf.QuietWriteByte('d') + assert.Eq(t, "ab-cd", buf.ResetAndGet()) + + buf.QuietWriteString("ab", "-", "cd") + buf.MustWriteString("-ef") + assert.Eq(t, "ab-cd-ef", buf.ResetAndGet()) + + buf.WriteAny(23, "abc") + assert.Eq(t, "23abc", buf.ResetAndGet()) + + buf.Writeln("abc") + assert.Eq(t, "abc\n", buf.ResetAndGet()) +} diff --git a/nbyte/bytex.go b/nbyte/bytex.go new file mode 100644 index 0000000..0d3448a --- /dev/null +++ b/nbyte/bytex.go @@ -0,0 +1,18 @@ +package nbyte + +import ( + "crypto/md5" + "fmt" +) + +// Md5 Generate a 32-bit md5 bytes +func Md5(src any) []byte { + h := md5.New() + + if s, ok := src.(string); ok { + h.Write([]byte(s)) + } else { + h.Write([]byte(fmt.Sprint(src))) + } + return h.Sum(nil) +} diff --git a/nbyte/check.go b/nbyte/check.go new file mode 100644 index 0000000..2786aaa --- /dev/null +++ b/nbyte/check.go @@ -0,0 +1,62 @@ +package nbyte + +// IsLower checks if a character is lower case ('a' to 'z') +func IsLower(c byte) bool { + return 'a' <= c && c <= 'z' +} + +// ToLower converts a character 'A' to 'Z' to its lower case +func ToLower(c byte) byte { + if c >= 'A' && c <= 'Z' { + return c + 32 + } + return c +} + +// ToLowerAll converts a character 'A' to 'Z' to its lower case +func ToLowerAll(bs []byte) []byte { + for i := range bs { + bs[i] = ToLower(bs[i]) + } + return bs +} + +// IsUpper checks if a character is upper case ('A' to 'Z') +func IsUpper(c byte) bool { + return 'A' <= c && c <= 'Z' +} + +// ToUpper converts a character 'a' to 'z' to its upper case +func ToUpper(r byte) byte { + if r >= 'a' && r <= 'z' { + return r - 32 + } + return r +} + +// ToUpperAll converts a character 'a' to 'z' to its upper case +func ToUpperAll(rs []byte) []byte { + for i := range rs { + rs[i] = ToUpper(rs[i]) + } + return rs +} + +// IsDigit checks if a character is digit ('0' to '9') +func IsDigit(r byte) bool { + return r >= '0' && r <= '9' +} + +// IsAlphabet byte +func IsAlphabet(char byte) bool { + // A 65 -> Z 90 + if char >= 'A' && char <= 'Z' { + return true + } + + // a 97 -> z 122 + if char >= 'a' && char <= 'z' { + return true + } + return false +} diff --git a/nbyte/encoder.go b/nbyte/encoder.go new file mode 100644 index 0000000..ea79bfa --- /dev/null +++ b/nbyte/encoder.go @@ -0,0 +1,63 @@ +package nbyte + +import ( + "encoding/base64" + "encoding/hex" +) + +// BytesEncoder interface +type BytesEncoder interface { + Encode(src []byte) []byte + Decode(src []byte) ([]byte, error) +} + +// StdEncoder implement the BytesEncoder +type StdEncoder struct { + encodeFn func(src []byte) []byte + decodeFn func(src []byte) ([]byte, error) +} + +// NewStdEncoder instance +func NewStdEncoder(encFn func(src []byte) []byte, decFn func(src []byte) ([]byte, error)) *StdEncoder { + return &StdEncoder{ + encodeFn: encFn, + decodeFn: decFn, + } +} + +// Encode input +func (e *StdEncoder) Encode(src []byte) []byte { + return e.encodeFn(src) +} + +// Decode input +func (e *StdEncoder) Decode(src []byte) ([]byte, error) { + return e.decodeFn(src) +} + +var ( + // HexEncoder instance + HexEncoder = NewStdEncoder(func(src []byte) []byte { + dst := make([]byte, hex.EncodedLen(len(src))) + hex.Encode(dst, src) + return dst + }, func(src []byte) ([]byte, error) { + n, err := hex.Decode(src, src) + return src[:n], err + }) + + // B64Encoder instance + B64Encoder = NewStdEncoder(func(src []byte) []byte { + b64Dst := make([]byte, base64.StdEncoding.EncodedLen(len(src))) + base64.StdEncoding.Encode(b64Dst, src) + return b64Dst + }, func(src []byte) ([]byte, error) { + dBuf := make([]byte, base64.StdEncoding.DecodedLen(len(src))) + n, err := base64.StdEncoding.Decode(dBuf, src) + if err != nil { + return nil, err + } + + return dBuf[:n], err + }) +) diff --git a/nbyte/encoder_test.go b/nbyte/encoder_test.go new file mode 100644 index 0000000..603746d --- /dev/null +++ b/nbyte/encoder_test.go @@ -0,0 +1,27 @@ +package nbyte_test + +import ( + "git.noahlan.cn/noahlan/ntool/nbyte" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestB64Encoder(t *testing.T) { + src := []byte("abc1234566") + dst := nbyte.B64Encoder.Encode(src) + assert.NotEmpty(t, dst) + + decSrc, err := nbyte.B64Encoder.Decode(dst) + assert.NoError(t, err) + assert.Eq(t, src, decSrc) +} + +func TestHexEncoder(t *testing.T) { + src := []byte("abc1234566") + dst := nbyte.HexEncoder.Encode(src) + assert.NotEmpty(t, dst) + + decSrc, err := nbyte.HexEncoder.Decode(dst) + assert.NoError(t, err) + assert.Eq(t, src, decSrc) +} diff --git a/nbyte/util.go b/nbyte/util.go new file mode 100644 index 0000000..7d5478c --- /dev/null +++ b/nbyte/util.go @@ -0,0 +1,104 @@ +package nbyte + +import ( + "bytes" + "fmt" + "strconv" + "time" + "unsafe" +) + +// FirstLine from command output +func FirstLine(bs []byte) []byte { + if i := bytes.IndexByte(bs, '\n'); i >= 0 { + return bs[0:i] + } + return bs +} + +// StrOrErr convert to string, return empty string on error. +func StrOrErr(bs []byte, err error) (string, error) { + if err != nil { + return "", err + } + return string(bs), err +} + +// SafeString convert to string, return empty string on error. +func SafeString(bs []byte, err error) string { + if err != nil { + return "" + } + return string(bs) +} + +// String unsafe convert bytes to string +func String(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +// ToString convert bytes to string +func ToString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +// AppendAny append any value to byte slice +func AppendAny(dst []byte, v any) []byte { + if v == nil { + return append(dst, ""...) + } + + switch val := v.(type) { + case []byte: + dst = append(dst, val...) + case string: + dst = append(dst, val...) + case int: + dst = strconv.AppendInt(dst, int64(val), 10) + case int8: + dst = strconv.AppendInt(dst, int64(val), 10) + case int16: + dst = strconv.AppendInt(dst, int64(val), 10) + case int32: + dst = strconv.AppendInt(dst, int64(val), 10) + case int64: + dst = strconv.AppendInt(dst, val, 10) + case uint: + dst = strconv.AppendUint(dst, uint64(val), 10) + case uint8: + dst = strconv.AppendUint(dst, uint64(val), 10) + case uint16: + dst = strconv.AppendUint(dst, uint64(val), 10) + case uint32: + dst = strconv.AppendUint(dst, uint64(val), 10) + case uint64: + dst = strconv.AppendUint(dst, val, 10) + case float32: + dst = strconv.AppendFloat(dst, float64(val), 'f', -1, 32) + case float64: + dst = strconv.AppendFloat(dst, val, 'f', -1, 64) + case bool: + dst = strconv.AppendBool(dst, val) + case time.Time: + dst = val.AppendFormat(dst, time.RFC3339) + case time.Duration: + dst = strconv.AppendInt(dst, int64(val), 10) + case error: + dst = append(dst, val.Error()...) + case fmt.Stringer: + dst = append(dst, val.String()...) + default: + dst = append(dst, fmt.Sprint(v)...) + } + return dst +} + +// Cut bytes. like the strings.Cut() +func Cut(bs []byte, sep byte) (before, after []byte, found bool) { + if i := bytes.IndexByte(bs, sep); i >= 0 { + return bs[:i], bs[i+1:], true + } + + before = bs + return +} diff --git a/nbyte/util_test.go b/nbyte/util_test.go new file mode 100644 index 0000000..879c330 --- /dev/null +++ b/nbyte/util_test.go @@ -0,0 +1,55 @@ +package nbyte_test + +import ( + "errors" + "git.noahlan.cn/noahlan/ntool/nbyte" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "git.noahlan.cn/noahlan/ntool/ntime" + "testing" +) + +func TestFirstLine(t *testing.T) { + bs := []byte("hi\ninhere") + assert.Eq(t, []byte("hi"), nbyte.FirstLine(bs)) + assert.Eq(t, []byte("hi"), nbyte.FirstLine([]byte("hi"))) +} + +func TestStrOrErr(t *testing.T) { + bs := []byte("hi, inhere") + assert.Eq(t, "hi, inhere", nbyte.SafeString(bs, nil)) + assert.Eq(t, "", nbyte.SafeString(bs, errors.New("error"))) + + str, err := nbyte.StrOrErr(bs, nil) + assert.NoErr(t, err) + assert.Eq(t, "hi, inhere", str) + + str, err = nbyte.StrOrErr(bs, errors.New("error")) + assert.Err(t, err) + assert.Eq(t, "", str) +} + +func TestMd5(t *testing.T) { + assert.NotEmpty(t, nbyte.Md5("abc")) + assert.NotEmpty(t, nbyte.Md5([]int{12, 34})) +} + +func TestAppendAny(t *testing.T) { + assert.Eq(t, []byte("123"), nbyte.AppendAny(nil, 123)) + assert.Eq(t, []byte("123"), nbyte.AppendAny([]byte{}, 123)) + assert.Eq(t, []byte("123"), nbyte.AppendAny([]byte("1"), 23)) + assert.Eq(t, []byte("1"), nbyte.AppendAny([]byte("1"), nil)) + assert.Eq(t, "3600000000000", string(nbyte.AppendAny([]byte{}, ntime.OneHour))) +} + +func TestCut(t *testing.T) { + // test for nbyte.Cut() + b, a, ok := nbyte.Cut([]byte("age=123"), '=') + assert.True(t, ok) + assert.Eq(t, []byte("age"), b) + assert.Eq(t, []byte("123"), a) + + b, a, ok = nbyte.Cut([]byte("age=123"), 'x') + assert.False(t, ok) + assert.Eq(t, []byte("age=123"), b) + assert.Empty(t, a) +} diff --git a/ncli/cmdline/builder.go b/ncli/cmdline/builder.go new file mode 100644 index 0000000..f554237 --- /dev/null +++ b/ncli/cmdline/builder.go @@ -0,0 +1,84 @@ +package cmdline + +import ( + "git.noahlan.cn/noahlan/ntool/nstr" + "strings" +) + +// LineBuilder build command line string. +// codes refer from strings.Builder +type LineBuilder struct { + strings.Builder +} + +// NewBuilder create +func NewBuilder(binFile string, args ...string) *LineBuilder { + b := &LineBuilder{} + + if binFile != "" { + b.AddArg(binFile) + } + + b.AddArray(args) + return b +} + +// AddArg to builder +func (b *LineBuilder) AddArg(arg string) { + _, _ = b.WriteString(arg) +} + +// AddArgs to builder +func (b *LineBuilder) AddArgs(args ...string) { + b.AddArray(args) +} + +// AddArray to builder +func (b *LineBuilder) AddArray(args []string) { + for _, arg := range args { + _, _ = b.WriteString(arg) + } +} + +// AddAny args to builder +func (b *LineBuilder) AddAny(args ...any) { + for _, arg := range args { + _, _ = b.WriteString(nstr.SafeString(arg)) + } +} + +// WriteString arg string to the builder, will auto quote special string. +// refer strconv.Quote() +func (b *LineBuilder) WriteString(a string) (int, error) { + var quote byte + if pos := strings.IndexByte(a, '"'); pos > -1 { + quote = '\'' + // fix: a = `--pretty=format:"one two three"` + if pos > 0 && '"' == a[len(a)-1] { + quote = 0 + } + } else if pos := strings.IndexByte(a, '\''); pos > -1 { + quote = '"' + // fix: a = "--pretty=format:'one two three'" + if pos > 0 && '\'' == a[len(a)-1] { + quote = 0 + } + } else if a == "" || strings.ContainsRune(a, ' ') { + quote = '"' + } + + // add sep on not-first write. + if b.Len() != 0 { + _ = b.WriteByte(' ') + } + + // no quote char OR not need quote + if quote == 0 { + return b.Builder.WriteString(a) + } + + _ = b.WriteByte(quote) // add start quote + n, err := b.Builder.WriteString(a) + _ = b.WriteByte(quote) // add end quote + return n, err +} diff --git a/ncli/cmdline/builder_test.go b/ncli/cmdline/builder_test.go new file mode 100644 index 0000000..229a6f3 --- /dev/null +++ b/ncli/cmdline/builder_test.go @@ -0,0 +1,52 @@ +package cmdline_test + +import ( + "git.noahlan.cn/noahlan/ntool/ncli/cmdline" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestLineBuild(t *testing.T) { + s := cmdline.LineBuild("myapp", []string{"-a", "val0", "arg0"}) + + assert.Eq(t, "myapp -a val0 arg0", s) + + // case: empty string + b := cmdline.NewBuilder("myapp", "-a", "") + + assert.Eq(t, 11, b.Len()) + assert.Eq(t, `myapp -a ""`, b.String()) + + b.Reset() + assert.Eq(t, 0, b.Len()) + + // case: add first + b.AddArg("myapp") + assert.Eq(t, `myapp`, b.String()) + + b.AddArgs("-a", "val0") + assert.Eq(t, "myapp -a val0", b.String()) + + // case: contains `"` + b.Reset() + b.AddArgs("myapp", "-a", `"val0"`) + assert.Eq(t, `myapp -a '"val0"'`, b.String()) + b.Reset() + b.AddArgs("myapp", "-a", `the "val0" of option`) + assert.Eq(t, `myapp -a 'the "val0" of option'`, b.String()) + + // case: contains `'` + b.Reset() + b.AddArgs("myapp", "-a", `'val0'`) + assert.Eq(t, `myapp -a "'val0'"`, b.String()) + b.Reset() + b.AddArgs("myapp", "-a", `the 'val0' of option`) + assert.Eq(t, `myapp -a "the 'val0' of option"`, b.String()) +} + +func TestLineBuild_hasQuote(t *testing.T) { + line := "git log --pretty=format:'one two three'" + args := cmdline.ParseLine(line) + // dump.P(args) + assert.Eq(t, line, cmdline.LineBuild("", args)) +} diff --git a/ncli/cmdline/cmdline.go b/ncli/cmdline/cmdline.go new file mode 100644 index 0000000..dab5ce1 --- /dev/null +++ b/ncli/cmdline/cmdline.go @@ -0,0 +1,41 @@ +package cmdline + +import ( + "fmt" + "strings" +) + +// LineBuild build command line string by given args. +func LineBuild(binFile string, args []string) string { + return NewBuilder(binFile, args...).String() +} + +// ParseLine input command line text. alias of the StringToOSArgs() +func ParseLine(line string) []string { + return NewParser(line).Parse() +} + +// Cmdline build +func Cmdline(args []string, binName ...string) string { + b := new(strings.Builder) + + if len(binName) > 0 { + b.WriteString(binName[0]) + b.WriteByte(' ') + } + + for i, a := range args { + if i > 0 { + b.WriteByte(' ') + } + + if strings.ContainsRune(a, '"') { + b.WriteString(fmt.Sprintf(`'%s'`, a)) + } else if a == "" || strings.ContainsRune(a, '\'') || strings.ContainsRune(a, ' ') { + b.WriteString(fmt.Sprintf(`"%s"`, a)) + } else { + b.WriteString(a) + } + } + return b.String() +} diff --git a/ncli/cmdline/parser.go b/ncli/cmdline/parser.go new file mode 100644 index 0000000..36ac92c --- /dev/null +++ b/ncli/cmdline/parser.go @@ -0,0 +1,173 @@ +package cmdline + +import ( + "bytes" + "git.noahlan.cn/noahlan/ntool/internal/common" + "git.noahlan.cn/noahlan/ntool/ndef" + "git.noahlan.cn/noahlan/ntool/nstr" + "os/exec" + "strings" +) + +// LineParser struct +// parse input command line to []string, such as cli os.Args +type LineParser struct { + parsed bool + // Line the full input command line text + // eg `kite top sub -a "this is a message" --foo val1 --bar "val 2"` + Line string + // ParseEnv parse ENV var on the line. + ParseEnv bool + // the exploded nodes by space. + nodes []string + // the parsed args + args []string + + // temp value + quoteChar byte + quoteIndex int // if > 0, mark is not on start + tempNode bytes.Buffer +} + +// NewParser create +func NewParser(line string) *LineParser { + return &LineParser{Line: line} +} + +// WithParseEnv with parse ENV var +func (p *LineParser) WithParseEnv() *LineParser { + p.ParseEnv = true + return p +} + +// AlsoEnvParse input command line text to os.Args, will parse ENV var +func (p *LineParser) AlsoEnvParse() []string { + p.ParseEnv = true + return p.Parse() +} + +// NewExecCmd quick create exec.Cmd by cmdline string +func (p *LineParser) NewExecCmd() *exec.Cmd { + // parse get bin and args + binName, args := p.BinAndArgs() + + // create a new Cmd instance + return exec.Command(binName, args...) +} + +// BinAndArgs get binName and args +func (p *LineParser) BinAndArgs() (bin string, args []string) { + p.Parse() // ensure parsed. + + ln := len(p.args) + if ln == 0 { + return + } + + bin = p.args[0] + if ln > 1 { + args = p.args[1:] + } + return +} + +// Parse input command line text to os.Args +func (p *LineParser) Parse() []string { + if p.parsed { + return p.args + } + + p.parsed = true + p.Line = strings.TrimSpace(p.Line) + if p.Line == "" { + return p.args + } + + // enable parse Env var + if p.ParseEnv { + p.Line = common.ParseEnvVar(p.Line, nil) + } + + p.nodes = strings.Split(p.Line, " ") + if len(p.nodes) == 1 { + p.args = p.nodes + return p.args + } + + for i := 0; i < len(p.nodes); i++ { + node := p.nodes[i] + if node == "" { + continue + } + + p.parseNode(node) + } + + p.nodes = p.nodes[:0] + if p.tempNode.Len() > 0 { + p.appendTempNode() + } + return p.args +} + +func (p *LineParser) parseNode(node string) { + maxIdx := len(node) - 1 + start, end := node[0], node[maxIdx] + + // in quotes + if p.quoteChar != 0 { + p.tempNode.WriteByte(' ') + + // end quotes + if end == p.quoteChar { + if p.quoteIndex > 0 { + p.tempNode.WriteString(node) // eg: node="--pretty=format:'one two'" + } else { + p.tempNode.WriteString(node[:maxIdx]) // remove last quote + } + p.appendTempNode() + } else { // goon ... write to temp node + p.tempNode.WriteString(node) + } + return + } + + // quote start + if start == ndef.DoubleQuote || start == ndef.SingleQuote { + // only one words. eg: `-m "msg"` + if end == start { + p.args = append(p.args, node[1:maxIdx]) + return + } + + p.quoteChar = start + p.tempNode.WriteString(node[1:]) + } else if end == ndef.DoubleQuote || end == ndef.SingleQuote { + p.args = append(p.args, node) // only one node: `msg"` + } else { + // eg: --pretty=format:'one two three' + if nstr.ContainsByte(node, ndef.DoubleQuote) { + p.quoteIndex = 1 // mark is not on start + p.quoteChar = ndef.DoubleQuote + } else if nstr.ContainsByte(node, ndef.SingleQuote) { + p.quoteIndex = 1 + p.quoteChar = ndef.SingleQuote + } + + // in quote, append to temp-node + if p.quoteChar != 0 { + p.tempNode.WriteString(node) + } else { + p.args = append(p.args, node) + } + } +} + +func (p *LineParser) appendTempNode() { + p.args = append(p.args, p.tempNode.String()) + + // reset context value + p.quoteChar = 0 + p.quoteIndex = 0 + p.tempNode.Reset() +} diff --git a/ncli/cmdline/parser_test.go b/ncli/cmdline/parser_test.go new file mode 100644 index 0000000..a1df6f9 --- /dev/null +++ b/ncli/cmdline/parser_test.go @@ -0,0 +1,145 @@ +package cmdline_test + +import ( + "git.noahlan.cn/noahlan/ntool/ncli/cmdline" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "git.noahlan.cn/noahlan/ntool/ntest/mock" + "strings" + "testing" +) + +func TestLineParser_Parse(t *testing.T) { + args := cmdline.NewParser(`./app top sub -a ddd --xx "msg"`).Parse() + assert.Len(t, args, 7) + assert.Eq(t, "msg", args[6]) + + args = cmdline.ParseLine(" ") + assert.Len(t, args, 0) + + args = cmdline.ParseLine("./app") + assert.Len(t, args, 1) + + p := cmdline.NewParser("./app sub ${A_ENV_VAR}") + p.WithParseEnv() + assert.True(t, p.ParseEnv) + + mock.MockEnvValue("A_ENV_VAR", "env-value", func(nv string) { + bin, args := p.BinAndArgs() + assert.Len(t, args, 2) + assert.Eq(t, "./app", bin) + assert.Eq(t, "env-value", args[1]) + + assert.NotEmpty(t, p.NewExecCmd()) + }) + + p = cmdline.NewParser("./app sub ${A_ENV_VAR2}") + mock.MockEnvValue("A_ENV_VAR2", "env-value2", func(nv string) { + args := p.AlsoEnvParse() + assert.Len(t, args, 3) + assert.Eq(t, "env-value2", args[2]) + }) +} + +func TestParseLine_Parse_withQuote(t *testing.T) { + tests := []struct { + line string + argN int + index int + value string + }{ + { + line: `./app top sub -a ddd --xx "abc +def"`, + argN: 7, index: 6, value: "abc\ndef", + }, + { + line: `./app top sub -a ddd --xx "abc +def ghi"`, + argN: 7, index: 6, value: "abc\ndef ghi", + }, + { + line: `./app top sub --msg "has multi words"`, + argN: 5, index: 4, value: "has multi words", + }, + { + line: `./app top sub --msg "has inner 'quote'"`, + argN: 5, index: 4, value: "has inner 'quote'", + }, + { + line: `./app top sub --msg "'has' inner quote"`, + argN: 5, index: 4, value: "'has' inner quote", + }, + { + line: `./app top sub --msg "has inner 'quote' words"`, + argN: 5, index: 4, value: "has inner 'quote' words", + }, + { + line: `./app top sub --msg "has 'inner quote' words"`, + argN: 5, index: 4, value: "has 'inner quote' words", + }, + { + line: `./app top sub --msg "has 'inner quote words'"`, + argN: 5, index: 4, value: "has 'inner quote words'", + }, + { + line: `./app top sub --msg "'has inner quote' words"`, + argN: 5, index: 4, value: "'has inner quote' words", + }, + } + + for _, tt := range tests { + args := cmdline.NewParser(tt.line).Parse() + assert.Len(t, args, tt.argN) + assert.Eq(t, tt.value, args[tt.index]) + } +} + +func TestParseLine_longLine(t *testing.T) { + line := "git log --pretty=format:'one two three'" + args := cmdline.ParseLine(line) + assert.Len(t, args, 3) + assert.Eq(t, "--pretty=format:'one two three'", args[2]) + + line = `git log --pretty=format:"one two three""` + args = cmdline.ParseLine(line) + assert.Len(t, args, 3) + assert.Eq(t, `--pretty=format:"one two three""`, args[2]) + + line = "git log --color --graph --pretty=format:'%Cred%h%Creset:%C(ul yellow)%d%Creset %s (%Cgreen%cr%Creset, %C(bold blue)%an%Creset)' --abbrev-commit -10" + args = cmdline.ParseLine(line) + //dump.P(args) + assert.Len(t, args, 7) + assert.Eq(t, "--graph", args[3]) + assert.Eq(t, "--abbrev-commit", args[5]) +} + +func TestParseLine_errLine(t *testing.T) { + // exception line string. + args := cmdline.NewParser(`./app top sub -a ddd --xx msg"`).Parse() + assert.Len(t, args, 7) + assert.Eq(t, "msg\"", args[6]) + + args = cmdline.ParseLine(`./app top sub -a ddd --xx "msg`) + assert.Len(t, args, 7) + assert.Eq(t, "msg", args[6]) + + args = cmdline.ParseLine(`./app top sub -a ddd --xx "msg text`) + assert.Len(t, args, 7) + assert.Eq(t, "msg text", args[6]) + + args = cmdline.ParseLine(`./app top sub -a ddd --xx "msg "text"`) + assert.Len(t, args, 7) + assert.Eq(t, "msg \"text", args[6]) +} + +func TestLineParser_BinAndArgs(t *testing.T) { + p := cmdline.NewParser("git status") + b, a := p.BinAndArgs() + assert.Eq(t, "git", b) + assert.Eq(t, "status", strings.Join(a, " ")) + + p = cmdline.NewParser("git") + b, a = p.BinAndArgs() + assert.Eq(t, "git", b) + assert.Empty(t, a) +} diff --git a/ncli/color_print.go b/ncli/color_print.go new file mode 100644 index 0000000..f6b1761 --- /dev/null +++ b/ncli/color_print.go @@ -0,0 +1,110 @@ +package ncli + +import "github.com/gookit/color" + +/************************************************************* + * quick use color print message + *************************************************************/ + +// Redp print message with Red color +func Redp(a ...any) { color.Red.Print(a...) } + +// Redf print message with Red color +func Redf(format string, a ...any) { color.Red.Printf(format, a...) } + +// Redln print message line with Red color +func Redln(a ...any) { color.Red.Println(a...) } + +// Bluep print message with Blue color +func Bluep(a ...any) { color.Blue.Print(a...) } + +// Bluef print message with Blue color +func Bluef(format string, a ...any) { color.Blue.Printf(format, a...) } + +// Blueln print message line with Blue color +func Blueln(a ...any) { color.Blue.Println(a...) } + +// Cyanp print message with Cyan color +func Cyanp(a ...any) { color.Cyan.Print(a...) } + +// Cyanf print message with Cyan color +func Cyanf(format string, a ...any) { color.Cyan.Printf(format, a...) } + +// Cyanln print message line with Cyan color +func Cyanln(a ...any) { color.Cyan.Println(a...) } + +// Grayp print message with gray color +func Grayp(a ...any) { color.Gray.Print(a...) } + +// Grayf print message with gray color +func Grayf(format string, a ...any) { color.Gray.Printf(format, a...) } + +// Grayln print message line with gray color +func Grayln(a ...any) { color.Gray.Println(a...) } + +// Greenp print message with green color +func Greenp(a ...any) { color.Green.Print(a...) } + +// Greenf print message with green color +func Greenf(format string, a ...any) { color.Green.Printf(format, a...) } + +// Greenln print message line with green color +func Greenln(a ...any) { color.Green.Println(a...) } + +// Yellowp print message with yellow color +func Yellowp(a ...any) { color.Yellow.Print(a...) } + +// Yellowf print message with yellow color +func Yellowf(format string, a ...any) { color.Yellow.Printf(format, a...) } + +// Yellowln print message line with yellow color +func Yellowln(a ...any) { color.Yellow.Println(a...) } + +// Magentap print message with magenta color +func Magentap(a ...any) { color.Magenta.Print(a...) } + +// Magentaf print message with magenta color +func Magentaf(format string, a ...any) { color.Magenta.Printf(format, a...) } + +// Magentaln print message line with magenta color +func Magentaln(a ...any) { color.Magenta.Println(a...) } + +/************************************************************* + * quick use style print message + *************************************************************/ + +// Infop print message with info color +func Infop(a ...any) { color.Info.Print(a...) } + +// Infof print message with info style +func Infof(format string, a ...any) { color.Info.Printf(format, a...) } + +// Infoln print message with info style +func Infoln(a ...any) { color.Info.Println(a...) } + +// Successp print message with success color +func Successp(a ...any) { color.Success.Print(a...) } + +// Successf print message with success style +func Successf(format string, a ...any) { color.Success.Printf(format, a...) } + +// Successln print message with success style +func Successln(a ...any) { color.Success.Println(a...) } + +// Errorp print message with error color +func Errorp(a ...any) { color.Error.Print(a...) } + +// Errorf print message with error style +func Errorf(format string, a ...any) { color.Error.Printf(format, a...) } + +// Errorln print message with error style +func Errorln(a ...any) { color.Error.Println(a...) } + +// Warnp print message with warn color +func Warnp(a ...any) { color.Warn.Print(a...) } + +// Warnf print message with warn style +func Warnf(format string, a ...any) { color.Warn.Printf(format, a...) } + +// Warnln print message with warn style +func Warnln(a ...any) { color.Warn.Println(a...) } diff --git a/ncli/info.go b/ncli/info.go new file mode 100644 index 0000000..be47dd1 --- /dev/null +++ b/ncli/info.go @@ -0,0 +1,54 @@ +package ncli + +import ( + "git.noahlan.cn/noahlan/ntool/internal/common" + "golang.org/x/term" + "os" + "path" +) + +// Workdir get +func Workdir() string { + return common.Workdir() +} + +// BinDir get +func BinDir() string { + return path.Dir(os.Args[0]) +} + +// BinFile get +func BinFile() string { + return os.Args[0] +} + +// BinName get +func BinName() string { + return path.Base(os.Args[0]) +} + +// exec: `stty -a 2>&1` +// const ( +// mac: speed 9600 baud; 97 rows; 362 columns; +// macSttyMsgPattern = `(\d+)\s+rows;\s*(\d+)\s+columns;` +// linux: speed 38400 baud; rows 97; columns 362; line = 0; +// linuxSttyMsgPattern = `rows\s+(\d+);\s*columns\s+(\d+);` +// ) +var terminalWidth, terminalHeight int + +// GetTermSize for current console terminal. +func GetTermSize(refresh ...bool) (w int, h int) { + if terminalWidth > 0 && len(refresh) > 0 && !refresh[0] { + return terminalWidth, terminalHeight + } + + var err error + w, h, err = term.GetSize(syscallStdinFd()) + if err != nil { + return + } + + // cache result + terminalWidth, terminalHeight = w, h + return +} diff --git a/ncli/info_nonwin.go b/ncli/info_nonwin.go new file mode 100644 index 0000000..546e346 --- /dev/null +++ b/ncli/info_nonwin.go @@ -0,0 +1,11 @@ +//go:build !windows + +package ncli + +import ( + "syscall" +) + +func syscallStdinFd() int { + return syscall.Stdin +} diff --git a/ncli/info_windows.go b/ncli/info_windows.go new file mode 100644 index 0000000..46ecbad --- /dev/null +++ b/ncli/info_windows.go @@ -0,0 +1,10 @@ +//go:build windows + +package ncli + +import "syscall" + +// on Windows, must convert 'syscall.Stdin' to int +func syscallStdinFd() int { + return int(syscall.Stdin) +} diff --git a/ncli/read.go b/ncli/read.go new file mode 100644 index 0000000..7704cbc --- /dev/null +++ b/ncli/read.go @@ -0,0 +1,142 @@ +package ncli + +import ( + "bufio" + "io" + "os" + "strings" + + "github.com/gookit/color" + "golang.org/x/term" +) + +// the global input output stream +var ( + // Input global input stream + Input io.Reader = os.Stdin + // Output global output stream + Output io.Writer = os.Stdout +) + +// ReadInput read user input form Stdin +func ReadInput(question string) (string, error) { + if len(question) > 0 { + color.Fprint(Output, question) + } + + scanner := bufio.NewScanner(Input) + if !scanner.Scan() { // reading + return "", scanner.Err() + } + + answer := scanner.Text() + return strings.TrimSpace(answer), nil +} + +// ReadLine read one line from user input. +// +// Usage: +// +// in := ncli.ReadLine("") +// ans, _ := ncli.ReadLine("your name?") +func ReadLine(question string) (string, error) { + if len(question) > 0 { + color.Fprint(Output, question) + } + + reader := bufio.NewReader(Input) + answer, _, err := reader.ReadLine() + return strings.TrimSpace(string(answer)), err +} + +// ReadFirst read first char +// +// Usage: +// +// ans, _ := ncli.ReadFirst("continue?[y/n] ") +func ReadFirst(question string) (string, error) { + answer, err := ReadFirstByte(question) + return string(answer), err +} + +// ReadFirstByte read first byte char +// +// Usage: +// +// ans, _ := ncli.ReadFirstByte("continue?[y/n] ") +func ReadFirstByte(question string) (byte, error) { + if len(question) > 0 { + color.Fprint(Output, question) + } + + reader := bufio.NewReader(Input) + return reader.ReadByte() +} + +// ReadFirstRune read first rune char +func ReadFirstRune(question string) (rune, error) { + if len(question) > 0 { + color.Fprint(Output, question) + } + + reader := bufio.NewReader(Input) + answer, _, err := reader.ReadRune() + return answer, err +} + +// ReadAsBool check user inputted answer is right +// +// Usage: +// +// ok := ReadAsBool("are you OK? [y/N]", false) +func ReadAsBool(tip string, defVal bool) bool { + fChar, err := ReadFirstByte(tip) + if err != nil { + panic(err) + } + + if fChar != 0 { + return ByteIsYes(fChar) + } + return defVal +} + +// ReadPassword from console terminal +func ReadPassword(question ...string) string { + if len(question) > 0 { + print(question[0]) + } else { + print("Enter Password: ") + } + + bs, err := term.ReadPassword(syscallStdinFd()) + if err != nil { + return "" + } + + println() // new line + return string(bs) +} + +// Confirm with user input +func Confirm(tip string, defVal ...bool) bool { + mark := " [y/N]: " + + var defV bool + if len(defVal) > 0 && defVal[0] { + defV = true + mark = " [Y/n]: " + } + + return ReadAsBool(tip+mark, defV) +} + +// InputIsYes answer: yes, y, Yes, Y +func InputIsYes(ans string) bool { + return len(ans) > 0 && (ans[0] == 'y' || ans[0] == 'Y') +} + +// ByteIsYes answer: yes, y, Yes, Y +func ByteIsYes(ans byte) bool { + return ans == 'y' || ans == 'Y' +} diff --git a/ncli/read_nonwin.go b/ncli/read_nonwin.go new file mode 100644 index 0000000..11c317d --- /dev/null +++ b/ncli/read_nonwin.go @@ -0,0 +1,3 @@ +//go:build !windows + +package ncli diff --git a/ncli/read_test.go b/ncli/read_test.go new file mode 100644 index 0000000..c7c54ae --- /dev/null +++ b/ncli/read_test.go @@ -0,0 +1,54 @@ +package ncli_test + +import ( + "git.noahlan.cn/noahlan/ntool/ncli" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestReadFirst(t *testing.T) { + // testutil.RewriteStdout() + // _, err := os.Stdout.WriteString("haha") + // ans, err1 := ncli.ReadFirst("hi?") + // testutil.RestoreStdout() + // assert.NoError(t, err) + // assert.NoError(t, err1) + // assert.Equal(t, "haha", ans) +} + +func TestInputIsYes(t *testing.T) { + tests := []struct { + in string + wnt bool + }{ + {"y", true}, + {"yes", true}, + {"yES", true}, + {"Y", true}, + {"Yes", true}, + {"YES", true}, + {"h", false}, + {"n", false}, + {"no", false}, + {"NO", false}, + } + for _, test := range tests { + assert.Eq(t, test.wnt, ncli.InputIsYes(test.in)) + } +} + +func TestByteIsYes(t *testing.T) { + tests := []struct { + in byte + wnt bool + }{ + {'y', true}, + {'Y', true}, + {'h', false}, + {'n', false}, + {'N', false}, + } + for _, test := range tests { + assert.Eq(t, test.wnt, ncli.ByteIsYes(test.in)) + } +} diff --git a/ncli/read_windows.go b/ncli/read_windows.go new file mode 100644 index 0000000..f90500c --- /dev/null +++ b/ncli/read_windows.go @@ -0,0 +1 @@ +package ncli diff --git a/ncli/util.go b/ncli/util.go new file mode 100644 index 0000000..9c535e5 --- /dev/null +++ b/ncli/util.go @@ -0,0 +1,147 @@ +package ncli + +import ( + "git.noahlan.cn/noahlan/ntool/internal/common" + "git.noahlan.cn/noahlan/ntool/ncli/cmdline" + "git.noahlan.cn/noahlan/ntool/nstr" + "strings" +) + +// LineBuild build command line string by given args. +func LineBuild(binFile string, args []string) string { + return cmdline.NewBuilder(binFile, args...).String() +} + +// BuildLine build command line string by given args. +func BuildLine(binFile string, args []string) string { + return cmdline.NewBuilder(binFile, args...).String() +} + +// String2OSArgs parse input command line text to os.Args +func String2OSArgs(line string) []string { + return cmdline.NewParser(line).Parse() +} + +// StringToOSArgs parse input command line text to os.Args +func StringToOSArgs(line string) []string { + return cmdline.NewParser(line).Parse() +} + +// ParseLine input command line text. alias of the StringToOSArgs() +func ParseLine(line string) []string { + return cmdline.NewParser(line).Parse() +} + +// QuickExec quick exec a simple command line +func QuickExec(cmdLine string, workDir ...string) (string, error) { + return ExecLine(cmdLine, workDir...) +} + +// ExecLine quick exec an command line string +func ExecLine(cmdLine string, workDir ...string) (string, error) { + p := cmdline.NewParser(cmdLine) + + // create a new Cmd instance + cmd := p.NewExecCmd() + if len(workDir) > 0 { + cmd.Dir = workDir[0] + } + + bs, err := cmd.Output() + return string(bs), err +} + +// ExecCommand alias of the ExecCmd() +func ExecCommand(binName string, args []string, workDir ...string) (string, error) { + return ExecCmd(binName, args, workDir...) +} + +// ExecCmd a command and return output. +// +// Usage: +// +// ExecCmd("ls", []string{"-al"}) +func ExecCmd(binName string, args []string, workDir ...string) (string, error) { + return common.ExecCmd(binName, args, workDir...) +} + +// ShellExec exec command by shell +// cmdLine e.g. "ls -al" +func ShellExec(cmdLine string, shells ...string) (string, error) { + return common.ShellExec(cmdLine, shells...) +} + +// CurrentShell get current used shell env file. +// +// eg "/bin/zsh" "/bin/bash". +// if onlyName=true, will return "zsh", "bash" +func CurrentShell(onlyName bool) (binPath string) { + return common.CurrentShell(onlyName) +} + +// HasShellEnv has shell env check. +// +// Usage: +// +// HasShellEnv("sh") +// HasShellEnv("bash") +func HasShellEnv(shell string) bool { + return common.HasShellEnv(shell) +} + +// BuildOptionHelpName for render flag help +func BuildOptionHelpName(names []string) string { + var sb strings.Builder + + size := len(names) - 1 + for i, name := range names { + sb.WriteByte('-') + if len(name) > 1 { + sb.WriteByte('-') + } + + sb.WriteString(name) + if i < size { + sb.WriteString(", ") + } + } + return sb.String() +} + +// ShellQuote quote a string on contains ', ", SPACE +func ShellQuote(s string) string { + var quote byte + if strings.ContainsRune(s, '"') { + quote = '\'' + } else if s == "" || strings.ContainsRune(s, '\'') || strings.ContainsRune(s, ' ') { + quote = '"' + } + + if quote > 0 { + ln := len(s) + 2 + bs := make([]byte, ln) + + bs[0] = quote + bs[ln-1] = quote + if ln > 2 { + copy(bs[1:ln-1], s) + } + + s = string(bs) + } + return s +} + +// OutputLines split output to lines +func OutputLines(output string) []string { + output = strings.TrimSuffix(output, "\n") + if output == "" { + return nil + } + return strings.Split(output, "\n") +} + +// FirstLine from command output +// +// Deprecated: please use nstr.FirstLine +var FirstLine = nstr.FirstLine diff --git a/ncli/util_test.go b/ncli/util_test.go new file mode 100644 index 0000000..2a27236 --- /dev/null +++ b/ncli/util_test.go @@ -0,0 +1,144 @@ +package ncli_test + +import ( + "git.noahlan.cn/noahlan/ntool/ncli" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "strings" + "testing" +) + +func TestCurrentShell(t *testing.T) { + path := ncli.CurrentShell(true) + + if path != "" { + assert.NotEmpty(t, path) + assert.True(t, ncli.HasShellEnv(path)) + + path = ncli.CurrentShell(false) + assert.NotEmpty(t, path) + } +} + +func TestExecCmd(t *testing.T) { + ret, err := ncli.ExecCmd("cmd", []string{"/c", "echo", "OK"}) + assert.NoErr(t, err) + // *nix: "OK\n" win: "OK\r\n" + assert.Eq(t, "OK", strings.TrimSpace(ret)) + + ret, err = ncli.ExecCommand("cmd", []string{"/c", "echo", "OK1"}) + assert.NoErr(t, err) + assert.Eq(t, "OK1", strings.TrimSpace(ret)) + + ret, err = ncli.QuickExec("cmd /c echo OK2") + assert.NoErr(t, err) + assert.Eq(t, "OK2", strings.TrimSpace(ret)) + + ret, err = ncli.ExecLine("cmd /c echo OK3") + assert.NoErr(t, err) + assert.Eq(t, "OK3", strings.TrimSpace(ret)) +} + +func TestShellExec(t *testing.T) { + ret, err := ncli.ShellExec("echo OK") + assert.NoErr(t, err) + // *nix: "OK\n" win: "OK\r\n" + assert.Eq(t, "OK", strings.TrimSpace(ret)) + + ret, err = ncli.ShellExec("echo OK", "powershell") + assert.NoErr(t, err) + assert.Eq(t, "OK", strings.TrimSpace(ret)) +} + +func TestLineBuild(t *testing.T) { + s := ncli.LineBuild("myapp", []string{"-a", "val0", "arg0"}) + assert.Eq(t, "myapp -a val0 arg0", s) + + s = ncli.BuildLine("./myapp", []string{ + "-a", "val0", + "-m", "this is message", + "arg0", + }) + assert.Eq(t, `./myapp -a val0 -m "this is message" arg0`, s) +} + +func TestParseLine(t *testing.T) { + args := ncli.ParseLine(`./app top sub -a ddd --xx "msg"`) + assert.Len(t, args, 7) + assert.Eq(t, "msg", args[6]) + + args = ncli.String2OSArgs(`./app top sub --msg "has inner 'quote'"`) + //dump.P(args) + assert.Len(t, args, 5) + assert.Eq(t, "has inner 'quote'", args[4]) + + // exception line string. + args = ncli.ParseLine(`./app top sub -a ddd --xx msg"`) + // dump.P(args) + assert.Len(t, args, 7) + assert.Eq(t, "msg\"", args[6]) + + args = ncli.StringToOSArgs(`./app top sub -a ddd --xx "msg "text"`) + // dump.P(args) + assert.Len(t, args, 7) + assert.Eq(t, "msg \"text", args[6]) +} + +func TestWorkdir(t *testing.T) { + assert.NotEmpty(t, ncli.Workdir()) + assert.NotEmpty(t, ncli.BinDir()) + assert.NotEmpty(t, ncli.BinFile()) + assert.NotEmpty(t, ncli.BinName()) +} + +func TestColorPrint(t *testing.T) { + // code gen by: kite gen parse ncli/_demo/gen-code.tpl + ncli.Redp("p:red color message, ") + ncli.Redf("f:%s color message, ", "red") + ncli.Redln("ln:red color message print in cli.") + ncli.Bluep("p:blue color message, ") + ncli.Bluef("f:%s color message, ", "blue") + ncli.Blueln("ln:blue color message print in cli.") + ncli.Cyanp("p:cyan color message, ") + ncli.Cyanf("f:%s color message, ", "cyan") + ncli.Cyanln("ln:cyan color message print in cli.") + ncli.Grayp("p:gray color message, ") + ncli.Grayf("f:%s color message, ", "gray") + ncli.Grayln("ln:gray color message print in cli.") + ncli.Greenp("p:green color message, ") + ncli.Greenf("f:%s color message, ", "green") + ncli.Greenln("ln:green color message print in cli.") + ncli.Yellowp("p:yellow color message, ") + ncli.Yellowf("f:%s color message, ", "yellow") + ncli.Yellowln("ln:yellow color message print in cli.") + ncli.Magentap("p:magenta color message, ") + ncli.Magentaf("f:%s color message, ", "magenta") + ncli.Magentaln("ln:magenta color message print in cli.") + + ncli.Infop("p:info color message, ") + ncli.Infof("f:%s color message, ", "info") + ncli.Infoln("ln:info color message print in cli.") + ncli.Successp("p:success color message, ") + ncli.Successf("f:%s color message, ", "success") + ncli.Successln("ln:success color message print in cli.") + ncli.Warnp("p:warn color message, ") + ncli.Warnf("f:%s color message, ", "warn") + ncli.Warnln("ln:warn color message print in cli.") + ncli.Errorp("p:error color message, ") + ncli.Errorf("f:%s color message, ", "error") + ncli.Errorln("ln:error color message print in cli.") +} + +func TestBuildOptionHelpName(t *testing.T) { + assert.Eq(t, "-a, -b", ncli.BuildOptionHelpName([]string{"a", "b"})) + assert.Eq(t, "-h, --help", ncli.BuildOptionHelpName([]string{"h", "help"})) +} + +func TestShellQuote(t *testing.T) { + assert.Eq(t, `"'"`, ncli.ShellQuote("'")) + assert.Eq(t, `""`, ncli.ShellQuote("")) + assert.Eq(t, `" "`, ncli.ShellQuote(" ")) + assert.Eq(t, `"ab s"`, ncli.ShellQuote("ab s")) + assert.Eq(t, `"ab's"`, ncli.ShellQuote("ab's")) + assert.Eq(t, `'ab"s'`, ncli.ShellQuote(`ab"s`)) + assert.Eq(t, "abs", ncli.ShellQuote("abs")) +} diff --git a/ncrypt/aes_des.go b/ncrypt/aes_des.go new file mode 100644 index 0000000..5041ee7 --- /dev/null +++ b/ncrypt/aes_des.go @@ -0,0 +1,411 @@ +package ncrypt + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/rand" + "errors" + "io" +) + +// ErrUnPadding error +var ErrUnPadding = errors.New("un-padding decrypted data fail") + +func GenerateAesKey(key []byte, size int) []byte { + genKey := make([]byte, size) + copy(genKey, key) + for i := size; i < len(key); { + for j := 0; j < size && i < len(key); j, i = j+1, i+1 { + genKey[j] ^= key[i] + } + } + return genKey +} + +func GenerateDesKey(key []byte) []byte { + genKey := make([]byte, 8) + copy(genKey, key) + for i := 8; i < len(key); { + for j := 0; j < 8 && i < len(key); j, i = j+1, i+1 { + genKey[j] ^= key[i] + } + } + return genKey +} + +// PKCS5Padding input data +func PKCS5Padding(src []byte, blockSize int) []byte { + padding := blockSize - len(src)%blockSize + padText := bytes.Repeat([]byte{byte(padding)}, padding) + return append(src, padText...) +} + +// PKCS5UnPadding input data +func PKCS5UnPadding(src []byte) ([]byte, error) { + length := len(src) + delLen := int(src[length-1]) + + if delLen > length { + return nil, ErrUnPadding + } + + // fix: 检查删除的填充是否是一样的字符,不一样说明 delLen 值是有问题的,无法解码 + if delLen > 1 && src[length-1] != src[length-2] { + return nil, ErrUnPadding + } + + return src[:length-delLen], nil +} + +// PKCS7Padding input data +func PKCS7Padding(src []byte, blockSize int) []byte { + return PKCS5Padding(src, blockSize) +} + +// PKCS7UnPadding input data +func PKCS7UnPadding(src []byte) ([]byte, error) { + return PKCS5UnPadding(src) +} + +// AesEcbEncrypt encrypt data with key use AES ECB algorithm +// len(key) should be 16, 24 or 32. +func AesEcbEncrypt(data, key []byte) []byte { + size := len(key) + if size != 16 && size != 24 && size != 32 { + panic("key length shoud be 16 or 24 or 32") + } + + length := (len(data) + aes.BlockSize) / aes.BlockSize + plain := make([]byte, length*aes.BlockSize) + + copy(plain, data) + + pad := byte(len(plain) - len(data)) + for i := len(data); i < len(plain); i++ { + plain[i] = pad + } + + encrypted := make([]byte, len(plain)) + cipher, _ := aes.NewCipher(GenerateAesKey(key, size)) + + for bs, be := 0, cipher.BlockSize(); bs <= len(data); bs, be = bs+cipher.BlockSize(), be+cipher.BlockSize() { + cipher.Encrypt(encrypted[bs:be], plain[bs:be]) + } + + return encrypted +} + +// AesEcbDecrypt decrypt data with key use AES ECB algorithm +// len(key) should be 16, 24 or 32. +func AesEcbDecrypt(encrypted, key []byte) []byte { + size := len(key) + if size != 16 && size != 24 && size != 32 { + panic("key length should be 16 or 24 or 32") + } + cipher, _ := aes.NewCipher(GenerateAesKey(key, size)) + decrypted := make([]byte, len(encrypted)) + + for bs, be := 0, cipher.BlockSize(); bs < len(encrypted); bs, be = bs+cipher.BlockSize(), be+cipher.BlockSize() { + cipher.Decrypt(decrypted[bs:be], encrypted[bs:be]) + } + + trim := 0 + if len(decrypted) > 0 { + trim = len(decrypted) - int(decrypted[len(decrypted)-1]) + } + + return decrypted[:trim] +} + +// AesCbcEncrypt encrypt data with key use AES CBC algorithm +// len(key) should be 16, 24 or 32. +func AesCbcEncrypt(data, key []byte) []byte { + block, _ := aes.NewCipher(key) + data = PKCS7Padding(data, block.BlockSize()) + + encrypted := make([]byte, aes.BlockSize+len(data)) + iv := encrypted[:aes.BlockSize] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + panic(err) + } + + mode := cipher.NewCBCEncrypter(block, iv) + mode.CryptBlocks(encrypted[aes.BlockSize:], data) + + return encrypted +} + +// AesCbcDecrypt decrypt data with key use AES CBC algorithm +// len(key) should be 16, 24 or 32. +func AesCbcDecrypt(encrypted, key []byte) ([]byte, error) { + block, _ := aes.NewCipher(key) + + iv := encrypted[:aes.BlockSize] + encrypted = encrypted[aes.BlockSize:] + + mode := cipher.NewCBCDecrypter(block, iv) + mode.CryptBlocks(encrypted, encrypted) + + return PKCS7UnPadding(encrypted) +} + +// AesCtrCrypt encrypt data with key use AES CTR algorithm +// len(key) should be 16, 24 or 32. +func AesCtrCrypt(data, key []byte) []byte { + block, _ := aes.NewCipher(key) + + iv := bytes.Repeat([]byte("1"), block.BlockSize()) + stream := cipher.NewCTR(block, iv) + + dst := make([]byte, len(data)) + stream.XORKeyStream(dst, data) + + return dst +} + +// AesCfbEncrypt encrypt data with key use AES CFB algorithm +// len(key) should be 16, 24 or 32. +func AesCfbEncrypt(data, key []byte) []byte { + block, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + + encrypted := make([]byte, aes.BlockSize+len(data)) + iv := encrypted[:aes.BlockSize] + + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + panic(err) + } + + stream := cipher.NewCFBEncrypter(block, iv) + stream.XORKeyStream(encrypted[aes.BlockSize:], data) + + return encrypted +} + +// AesCfbDecrypt decrypt data with key use AES CFB algorithm +// len(encrypted) should be greater than 16, len(key) should be 16, 24 or 32. +func AesCfbDecrypt(encrypted, key []byte) []byte { + if len(encrypted) < aes.BlockSize { + panic("encrypted data is too short") + } + + block, _ := aes.NewCipher(key) + iv := encrypted[:aes.BlockSize] + encrypted = encrypted[aes.BlockSize:] + + stream := cipher.NewCFBDecrypter(block, iv) + + stream.XORKeyStream(encrypted, encrypted) + + return encrypted +} + +// AesOfbEncrypt encrypt data with key use AES OFB algorithm +// len(key) should be 16, 24 or 32. +func AesOfbEncrypt(data, key []byte) []byte { + block, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + + data = PKCS7Padding(data, aes.BlockSize) + encrypted := make([]byte, aes.BlockSize+len(data)) + iv := encrypted[:aes.BlockSize] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + panic(err) + } + + stream := cipher.NewOFB(block, iv) + stream.XORKeyStream(encrypted[aes.BlockSize:], data) + + return encrypted +} + +// AesOfbDecrypt decrypt data with key use AES OFB algorithm +// len(key) should be 16, 24 or 32. +func AesOfbDecrypt(data, key []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + + iv := data[:aes.BlockSize] + data = data[aes.BlockSize:] + if len(data)%aes.BlockSize != 0 { + return nil, errors.New("data must % 16") + } + + decrypted := make([]byte, len(data)) + mode := cipher.NewOFB(block, iv) + mode.XORKeyStream(decrypted, data) + + return PKCS7UnPadding(decrypted) +} + +// DesEcbEncrypt encrypt data with key use DES ECB algorithm +// len(key) should be 8. +func DesEcbEncrypt(data, key []byte) []byte { + length := (len(data) + des.BlockSize) / des.BlockSize + plain := make([]byte, length*des.BlockSize) + copy(plain, data) + + pad := byte(len(plain) - len(data)) + for i := len(data); i < len(plain); i++ { + plain[i] = pad + } + + encrypted := make([]byte, len(plain)) + cipher, _ := des.NewCipher(GenerateDesKey(key)) + + for bs, be := 0, cipher.BlockSize(); bs <= len(data); bs, be = bs+cipher.BlockSize(), be+cipher.BlockSize() { + cipher.Encrypt(encrypted[bs:be], plain[bs:be]) + } + + return encrypted +} + +// DesEcbDecrypt decrypt data with key use DES ECB algorithm +// len(key) should be 8. +func DesEcbDecrypt(encrypted, key []byte) []byte { + cipher, _ := des.NewCipher(GenerateDesKey(key)) + decrypted := make([]byte, len(encrypted)) + + for bs, be := 0, cipher.BlockSize(); bs < len(encrypted); bs, be = bs+cipher.BlockSize(), be+cipher.BlockSize() { + cipher.Decrypt(decrypted[bs:be], encrypted[bs:be]) + } + + trim := 0 + if len(decrypted) > 0 { + trim = len(decrypted) - int(decrypted[len(decrypted)-1]) + } + + return decrypted[:trim] +} + +// DesCbcEncrypt encrypt data with key use DES CBC algorithm +// len(key) should be 8. +func DesCbcEncrypt(data, key []byte) []byte { + block, _ := des.NewCipher(key) + data = PKCS7Padding(data, block.BlockSize()) + + encrypted := make([]byte, des.BlockSize+len(data)) + iv := encrypted[:des.BlockSize] + + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + panic(err) + } + + mode := cipher.NewCBCEncrypter(block, iv) + mode.CryptBlocks(encrypted[des.BlockSize:], data) + + return encrypted +} + +// DesCbcDecrypt decrypt data with key use DES CBC algorithm +// len(key) should be 8. +func DesCbcDecrypt(encrypted, key []byte) ([]byte, error) { + block, _ := des.NewCipher(key) + + iv := encrypted[:des.BlockSize] + encrypted = encrypted[des.BlockSize:] + + mode := cipher.NewCBCDecrypter(block, iv) + mode.CryptBlocks(encrypted, encrypted) + + return PKCS7UnPadding(encrypted) +} + +// DesCtrCrypt encrypt data with key use DES CTR algorithm +// len(key) should be 8. +func DesCtrCrypt(data, key []byte) []byte { + block, _ := des.NewCipher(key) + + iv := bytes.Repeat([]byte("1"), block.BlockSize()) + stream := cipher.NewCTR(block, iv) + + dst := make([]byte, len(data)) + stream.XORKeyStream(dst, data) + + return dst +} + +// DesCfbEncrypt encrypt data with key use DES CFB algorithm +// len(key) should be 8. +func DesCfbEncrypt(data, key []byte) []byte { + block, err := des.NewCipher(key) + if err != nil { + panic(err) + } + + encrypted := make([]byte, des.BlockSize+len(data)) + iv := encrypted[:des.BlockSize] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + panic(err) + } + + stream := cipher.NewCFBEncrypter(block, iv) + stream.XORKeyStream(encrypted[des.BlockSize:], data) + + return encrypted +} + +// DesCfbDecrypt decrypt data with key use DES CFB algorithm +// len(encrypted) should be greater than 16, len(key) should be 8. +func DesCfbDecrypt(encrypted, key []byte) []byte { + block, _ := des.NewCipher(key) + if len(encrypted) < des.BlockSize { + panic("encrypted data is too short") + } + iv := encrypted[:des.BlockSize] + encrypted = encrypted[des.BlockSize:] + + stream := cipher.NewCFBDecrypter(block, iv) + stream.XORKeyStream(encrypted, encrypted) + + return encrypted +} + +// DesOfbEncrypt encrypt data with key use DES OFB algorithm +// len(key) should be 16, 24 or 32. +func DesOfbEncrypt(data, key []byte) []byte { + block, err := des.NewCipher(key) + if err != nil { + panic(err) + } + data = PKCS7Padding(data, des.BlockSize) + encrypted := make([]byte, des.BlockSize+len(data)) + iv := encrypted[:des.BlockSize] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + panic(err) + } + + stream := cipher.NewOFB(block, iv) + stream.XORKeyStream(encrypted[des.BlockSize:], data) + + return encrypted +} + +// DesOfbDecrypt decrypt data with key use DES OFB algorithm +// len(key) should be 8. +func DesOfbDecrypt(data, key []byte) ([]byte, error) { + block, err := des.NewCipher(key) + if err != nil { + panic(err) + } + + iv := data[:des.BlockSize] + data = data[des.BlockSize:] + if len(data)%des.BlockSize != 0 { + return nil, errors.New("data must % 16") + } + + decrypted := make([]byte, len(data)) + mode := cipher.NewOFB(block, iv) + mode.XORKeyStream(decrypted, data) + + return PKCS7UnPadding(decrypted) +} diff --git a/ncrypt/base64.go b/ncrypt/base64.go new file mode 100644 index 0000000..1a6aa52 --- /dev/null +++ b/ncrypt/base64.go @@ -0,0 +1,14 @@ +package ncrypt + +import "git.noahlan.cn/noahlan/ntool/nbyte" + +// Base64Encode encode data with base64 encoding. +func Base64Encode(s []byte) []byte { + return nbyte.B64Encoder.Encode(s) +} + +// Base64EncodeStr encode string data with base64 encoding. +func Base64EncodeStr(s string) string { + bs := Base64Encode([]byte(s)) + return string(bs) +} diff --git a/ncrypt/bcrypt.go b/ncrypt/bcrypt.go new file mode 100644 index 0000000..728594b --- /dev/null +++ b/ncrypt/bcrypt.go @@ -0,0 +1,18 @@ +package ncrypt + +import "golang.org/x/crypto/bcrypt" + +// BcryptEncrypt Bcrypt 加密 +func BcryptEncrypt(password string, code int) string { + if code == 0 { + code = bcrypt.DefaultCost + } + bs, _ := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + return string(bs) +} + +// BcryptCheck Bcrypt 检查 +func BcryptCheck(password, hash string) bool { + err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) + return err == nil +} diff --git a/ncrypt/hmac.go b/ncrypt/hmac.go new file mode 100644 index 0000000..f11fdef --- /dev/null +++ b/ncrypt/hmac.go @@ -0,0 +1,38 @@ +package ncrypt + +import ( + "crypto/hmac" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "encoding/hex" +) + +// HmacMd5 return the hmac hash of string use md5. +func HmacMd5(data, key string) string { + h := hmac.New(md5.New, []byte(key)) + h.Write([]byte(data)) + return hex.EncodeToString(h.Sum([]byte(""))) +} + +// HmacSha1 return the hmac hash of string use sha1. +func HmacSha1(data, key string) string { + h := hmac.New(sha1.New, []byte(key)) + h.Write([]byte(data)) + return hex.EncodeToString(h.Sum([]byte(""))) +} + +// HmacSha256 return the hmac hash of string use sha256. +func HmacSha256(data, key string) string { + h := hmac.New(sha256.New, []byte(key)) + h.Write([]byte(data)) + return hex.EncodeToString(h.Sum([]byte(""))) +} + +// HmacSha512 return the hmac hash of string use sha512. +func HmacSha512(data, key string) string { + h := hmac.New(sha512.New, []byte(key)) + h.Write([]byte(data)) + return hex.EncodeToString(h.Sum([]byte(""))) +} diff --git a/ncrypt/md5.go b/ncrypt/md5.go new file mode 100644 index 0000000..605980c --- /dev/null +++ b/ncrypt/md5.go @@ -0,0 +1,52 @@ +package ncrypt + +import ( + "bufio" + "crypto/md5" + "encoding/hex" + "fmt" + "git.noahlan.cn/noahlan/ntool/nbyte" + "io" + "os" +) + +// Md5Bytes return the md5 value of bytes. +func Md5Bytes(b any) []byte { + return nbyte.Md5(b) +} + +// Md5String return the md5 value of string. +func Md5String(s any) string { + return hex.EncodeToString(Md5Bytes(s)) +} + +// Md5File return the md5 value of file. +func Md5File(filename string) (string, error) { + if fileInfo, err := os.Stat(filename); err != nil { + return "", err + } else if fileInfo.IsDir() { + return "", nil + } + + file, err := os.Open(filename) + if err != nil { + return "", err + } + defer file.Close() + + hash := md5.New() + chunkSize := 65536 + for buf, reader := make([]byte, chunkSize), bufio.NewReader(file); ; { + n, err := reader.Read(buf) + if err != nil { + if err == io.EOF { + break + } + return "", err + } + hash.Write(buf[:n]) + } + + checksum := fmt.Sprintf("%x", hash.Sum(nil)) + return checksum, nil +} diff --git a/ncrypt/rsa.go b/ncrypt/rsa.go new file mode 100644 index 0000000..3c33f53 --- /dev/null +++ b/ncrypt/rsa.go @@ -0,0 +1,128 @@ +package ncrypt + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "os" +) + +// GenerateRsaKey create rsa private and public pemo file. +func GenerateRsaKey(keySize int, priKeyFile, pubKeyFile string) error { + // private key + privateKey, err := rsa.GenerateKey(rand.Reader, keySize) + if err != nil { + return err + } + + derText := x509.MarshalPKCS1PrivateKey(privateKey) + + block := pem.Block{ + Type: "rsa private key", + Bytes: derText, + } + + file, err := os.Create(priKeyFile) + if err != nil { + panic(err) + } + err = pem.Encode(file, &block) + if err != nil { + return err + } + + file.Close() + + // public key + publicKey := privateKey.PublicKey + + derpText, err := x509.MarshalPKIXPublicKey(&publicKey) + if err != nil { + return err + } + + block = pem.Block{ + Type: "rsa public key", + Bytes: derpText, + } + + file, err = os.Create(pubKeyFile) + if err != nil { + return err + } + + err = pem.Encode(file, &block) + if err != nil { + return err + } + + file.Close() + + return nil +} + +// RsaEncrypt encrypt data with ras algorithm. +func RsaEncrypt(data []byte, pubKeyFileName string) []byte { + file, err := os.Open(pubKeyFileName) + if err != nil { + panic(err) + } + fileInfo, err := file.Stat() + if err != nil { + panic(err) + } + defer file.Close() + buf := make([]byte, fileInfo.Size()) + + _, err = file.Read(buf) + if err != nil { + panic(err) + } + + block, _ := pem.Decode(buf) + + pubInterface, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + panic(err) + } + pubKey := pubInterface.(*rsa.PublicKey) + + cipherText, err := rsa.EncryptPKCS1v15(rand.Reader, pubKey, data) + if err != nil { + panic(err) + } + return cipherText +} + +// RsaDecrypt decrypt data with ras algorithm. +func RsaDecrypt(data []byte, privateKeyFileName string) []byte { + file, err := os.Open(privateKeyFileName) + if err != nil { + panic(err) + } + fileInfo, err := file.Stat() + if err != nil { + panic(err) + } + buf := make([]byte, fileInfo.Size()) + defer file.Close() + + _, err = file.Read(buf) + if err != nil { + panic(err) + } + + block, _ := pem.Decode(buf) + + priKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + panic(err) + } + + plainText, err := rsa.DecryptPKCS1v15(rand.Reader, priKey, data) + if err != nil { + panic(err) + } + return plainText +} diff --git a/ncrypt/sha.go b/ncrypt/sha.go new file mode 100644 index 0000000..555b40c --- /dev/null +++ b/ncrypt/sha.go @@ -0,0 +1,29 @@ +package ncrypt + +import ( + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "encoding/hex" +) + +// Sha1 return the sha1 value (SHA-1 hash algorithm) of string. +func Sha1(data string) string { + s := sha1.New() + s.Write([]byte(data)) + return hex.EncodeToString(s.Sum([]byte(""))) +} + +// Sha256 return the sha256 value (SHA256 hash algorithm) of string. +func Sha256(data string) string { + s := sha256.New() + s.Write([]byte(data)) + return hex.EncodeToString(s.Sum([]byte(""))) +} + +// Sha512 return the sha512 value (SHA512 hash algorithm) of string. +func Sha512(data string) string { + s := sha512.New() + s.Write([]byte(data)) + return hex.EncodeToString(s.Sum([]byte(""))) +} diff --git a/ndef/consts.go b/ndef/consts.go new file mode 100644 index 0000000..84e66ef --- /dev/null +++ b/ndef/consts.go @@ -0,0 +1,27 @@ +package ndef + +// const for compare operation +const ( + OpEq = "=" + OpNeq = "!=" + OpLt = "<" + OpLte = "<=" + OpGt = ">" + OpGte = ">=" +) + +// const quote chars +const ( + SingleQuote = '\'' + DoubleQuote = '"' + SlashQuote = '\\' + + SingleQuoteStr = "'" + DoubleQuoteStr = `"` + SlashQuoteStr = "\\" +) + +// NoIdx invalid index or length +const NoIdx = -1 + +// const VarPathReg = `(\w[\w-]*(?:\.[\w-]+)*)` diff --git a/ndef/errors.go b/ndef/errors.go new file mode 100644 index 0000000..4f93c1f --- /dev/null +++ b/ndef/errors.go @@ -0,0 +1,6 @@ +package ndef + +import "errors" + +// ErrConvType error +var ErrConvType = errors.New("convert value type error") diff --git a/ndef/formatter.go b/ndef/formatter.go new file mode 100644 index 0000000..cc1fe63 --- /dev/null +++ b/ndef/formatter.go @@ -0,0 +1,56 @@ +package ndef + +import ( + "bytes" + nio "git.noahlan.cn/noahlan/ntool/nstd/io" + "io" +) + +// DataFormatter interface +type DataFormatter interface { + Format() string + FormatTo(w io.Writer) +} + +// BaseFormatter struct +type BaseFormatter struct { + ow nio.ByteStringWriter + // Out formatted to the writer + Out io.Writer + // Src data(array, map, struct) for format + Src any + // MaxDepth limit depth for array, map data TODO + MaxDepth int + // Prefix string for each element + Prefix string + // Indent string for format each element + Indent string + // ClosePrefix string for last "]", "}" + ClosePrefix string +} + +// Reset after format +func (f *BaseFormatter) Reset() { + f.Out = nil + f.Src = nil +} + +// SetOutput writer +func (f *BaseFormatter) SetOutput(out io.Writer) { + f.Out = out +} + +// BsWriter build and get +func (f *BaseFormatter) BsWriter() nio.ByteStringWriter { + if f.ow == nil { + if f.Out == nil { + f.ow = new(bytes.Buffer) + } else if ow, ok := f.Out.(nio.ByteStringWriter); ok { + f.ow = ow + } else { + f.ow = nio.NewWriteWrapper(f.Out) + } + } + + return f.ow +} diff --git a/ndef/serializer.go b/ndef/serializer.go new file mode 100644 index 0000000..6acd558 --- /dev/null +++ b/ndef/serializer.go @@ -0,0 +1,37 @@ +package ndef + +type ( + MarshalFunc func(v any) ([]byte, error) + UnmarshalFunc func(data []byte, v any) error + + Marshaler interface { + Marshal(v any) ([]byte, error) + } + + Unmarshaler interface { + Unmarshal(data []byte, v any) error + } + + Serializer interface { + Marshaler + Unmarshaler + } +) + +type ( + MarshalerWrapper struct { + Marshaler + } + UnmarshalerWrapper struct { + Unmarshaler + } + SerializerWrapper struct { + Marshaler + Unmarshaler + } +) + +// NewSerializerWrapper 序列化器包装,用于将序列化/反序列化包装为一个独立结构 +func NewSerializerWrapper(marshaler Marshaler, unmarshaler Unmarshaler) Serializer { + return &SerializerWrapper{Marshaler: marshaler, Unmarshaler: unmarshaler} +} diff --git a/ndef/symbols.go b/ndef/symbols.go new file mode 100644 index 0000000..b0b3223 --- /dev/null +++ b/ndef/symbols.go @@ -0,0 +1,41 @@ +package ndef + +const ( + // CommaStr const define + CommaStr = "," + // CommaChar define + CommaChar = ',' + + // EqualStr define + EqualStr = "=" + // EqualChar define + EqualChar = '=' + + // ColonStr define + ColonStr = ":" + // ColonChar define + ColonChar = ':' + + // SemicolonStr semicolon define + SemicolonStr = ";" + // SemicolonChar define + SemicolonChar = ';' + + // PathStr define const + PathStr = "/" + // PathChar define + PathChar = '/' + + // DefaultSep comma string + DefaultSep = "," + + // SpaceChar char + SpaceChar = ' ' + // SpaceStr string + SpaceStr = " " + + // NewlineChar char + NewlineChar = '\n' + // NewlineStr string + NewlineStr = "\n" +) diff --git a/ndef/types.go b/ndef/types.go new file mode 100644 index 0000000..dbac06d --- /dev/null +++ b/ndef/types.go @@ -0,0 +1,46 @@ +package ndef + +// Int interface type +type Int interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 +} + +// Uint interface type +type Uint interface { + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 | ~uintptr +} + +// XInt interface type. all int or uint types +type XInt interface { + Int | Uint +} + +// Float interface type +type Float interface { + ~float32 | ~float64 +} + +// IntOrFloat interface type. all int and float types +type IntOrFloat interface { + Int | Float +} + +// XIntOrFloat interface type. all int, uint and float types +type XIntOrFloat interface { + Int | Uint | Float +} + +// SortedType interface type. +// that supports the operators < <= >= >. +// +// contains: (x)int, float, ~string types +type SortedType interface { + Int | Uint | Float | ~string +} + +// ScalarType interface type. +// +// contains: (x)int, float, ~string, ~bool types +type ScalarType interface { + Int | Uint | Float | ~string | ~bool +} diff --git a/nenv/info.go b/nenv/info.go new file mode 100644 index 0000000..1a21cb1 --- /dev/null +++ b/nenv/info.go @@ -0,0 +1,169 @@ +package nenv + +import ( + "git.noahlan.cn/noahlan/ntool/nsys" + "golang.org/x/term" + "io" + "os" + "runtime" + "strings" +) + +// IsWin system. linux windows darwin +func IsWin() bool { + return runtime.GOOS == "windows" +} + +// IsWindows system. alias of IsWin +func IsWindows() bool { + return runtime.GOOS == "windows" +} + +// IsMac system +func IsMac() bool { + return runtime.GOOS == "darwin" +} + +// IsLinux system +func IsLinux() bool { + return runtime.GOOS == "linux" +} + +// IsMSys msys(MINGW64) env. alias of the nsys.IsMSys() +func IsMSys() bool { + return nsys.IsMSys() +} + +var detectedWSL bool +var detectedWSLContents string + +// IsWSL system env +// https://github.com/Microsoft/WSL/issues/423#issuecomment-221627364 +func IsWSL() bool { + if !detectedWSL { + b := make([]byte, 1024) + // `cat /proc/version` + // on mac: + // !not the file! + // on linux(debian,ubuntu,alpine): + // Linux version 4.19.121-linuxkit (root@18b3f92ade35) (gcc version 9.2.0 (Alpine 9.2.0)) #1 SMP Thu Jan 21 15:36:34 UTC 2021 + // on win git bash, conEmu: + // MINGW64_NT-10.0-19042 version 3.1.7-340.x86_64 (@WIN-N0G619FD3UK) (gcc version 9.3.0 (GCC) ) 2020-10-23 13:08 UTC + // on WSL: + // Linux version 4.4.0-19041-Microsoft (Microsoft@Microsoft.com) (gcc version 5.4.0 (GCC) ) #488-Microsoft Mon Sep 01 13:43:00 PST 2020 + f, err := os.Open("/proc/version") + if err == nil { + _, _ = f.Read(b) // ignore error + f.Close() + detectedWSLContents = string(b) + } + detectedWSL = true + } + return strings.Contains(detectedWSLContents, "Microsoft") +} + +// IsTerminal isatty check +// +// Usage: +// +// envutil.IsTerminal(os.Stdout.Fd()) +func IsTerminal(fd uintptr) bool { + // return isatty.IsTerminal(fd) // "github.com/mattn/go-isatty" + return term.IsTerminal(int(fd)) +} + +// StdIsTerminal os.Stdout is terminal +func StdIsTerminal() bool { + return IsTerminal(os.Stdout.Fd()) +} + +// IsConsole check out is console env. alias of the nsys.IsConsole() +func IsConsole(out io.Writer) bool { + return nsys.IsConsole(out) +} + +// HasShellEnv has shell env check. +// +// Usage: +// +// HasShellEnv("sh") +// HasShellEnv("bash") +func HasShellEnv(shell string) bool { + return nsys.HasShellEnv(shell) +} + +// Support color: +// +// "TERM=xterm" +// "TERM=xterm-vt220" +// "TERM=xterm-256color" +// "TERM=screen-256color" +// "TERM=tmux-256color" +// "TERM=rxvt-unicode-256color" +// +// Don't support color: +// +// "TERM=cygwin" +var specialColorTerms = map[string]bool{ + "alacritty": true, +} + +// IsSupportColor check current console is support color. +// +// Supported: +// +// linux, mac, or windows's ConEmu, Cmder, putty, git-bash.exe +// +// Not support: +// +// windows cmd.exe, powerShell.exe +func IsSupportColor() bool { + envTerm := os.Getenv("TERM") + if strings.Contains(envTerm, "xterm") { + return true + } + + // it's special color term + if _, ok := specialColorTerms[envTerm]; ok { + return true + } + + // like on ConEmu software, e.g "ConEmuANSI=ON" + if os.Getenv("ConEmuANSI") == "ON" { + return true + } + + // like on ConEmu software, e.g "ANSICON=189x2000 (189x43)" + if os.Getenv("ANSICON") != "" { + return true + } + + // up: if support 256-color, can also support basic color. + return IsSupport256Color() +} + +// IsSupport256Color render +func IsSupport256Color() bool { + // "TERM=xterm-256color" + // "TERM=screen-256color" + // "TERM=tmux-256color" + // "TERM=rxvt-unicode-256color" + supported := strings.Contains(os.Getenv("TERM"), "256color") + if !supported { + // up: if support true-color, can also support 256-color. + supported = IsSupportTrueColor() + } + + return supported +} + +// IsSupportTrueColor render. IsSupportRGBColor +func IsSupportTrueColor() bool { + // "COLORTERM=truecolor" + return strings.Contains(os.Getenv("COLORTERM"), "truecolor") +} + +// IsGithubActions env +func IsGithubActions() bool { + return os.Getenv("GITHUB_ACTIONS") == "true" +} diff --git a/nenv/parse.go b/nenv/parse.go new file mode 100644 index 0000000..d5e7ba6 --- /dev/null +++ b/nenv/parse.go @@ -0,0 +1,25 @@ +package nenv + +import ( + "git.noahlan.cn/noahlan/ntool/internal/common" +) + +// Environ like os.Environ, but will returns key-value map[string]string data. +func Environ() map[string]string { + return common.Environ() +} + +// ParseEnvVar parse ENV var value from input string, support default value. +// +// Format: +// +// ${var_name} Only var name +// ${var_name | default} With default value +// +// Usage: +// +// comfunc.ParseEnvVar("${ APP_NAME }") +// comfunc.ParseEnvVar("${ APP_ENV | dev }") +func ParseEnvVar(val string, getFn func(string) string) (newVal string) { + return common.ParseEnvVar(val, getFn) +} diff --git a/nfs/check.go b/nfs/check.go new file mode 100644 index 0000000..49f5280 --- /dev/null +++ b/nfs/check.go @@ -0,0 +1,136 @@ +package nfs + +import ( + "bytes" + "os" + "path" + "path/filepath" +) + +// perm for create dir or file +var ( + DefaultDirPerm os.FileMode = 0775 + DefaultFilePerm os.FileMode = 0665 + OnlyReadFilePerm os.FileMode = 0444 +) + +var ( + // DefaultFileFlags for create and write + DefaultFileFlags = os.O_CREATE | os.O_WRONLY | os.O_APPEND + // OnlyReadFileFlags open file for read + OnlyReadFileFlags = os.O_RDONLY +) + +// PathExists reports whether the named file or directory exists. +func PathExists(path string) bool { + if path == "" { + return false + } + + if _, err := os.Stat(path); err != nil { + if os.IsNotExist(err) { + return false + } + } + return true +} + +// DirExists reports whether the named directory exists. +func DirExists(path string) bool { + return IsDir(path) +} + +// IsDir reports whether the named directory exists. +func IsDir(path string) bool { + if path == "" || len(path) > 468 { + return false + } + + if fi, err := os.Stat(path); err == nil { + return fi.IsDir() + } + return false +} + +// FileExists reports whether the named file or directory exists. +func FileExists(path string) bool { + return IsFile(path) +} + +// IsFile reports whether the named file or directory exists. +func IsFile(path string) bool { + if path == "" || len(path) > 468 { + return false + } + + if fi, err := os.Stat(path); err == nil { + return !fi.IsDir() + } + return false +} + +// IsAbsPath is abs path. +func IsAbsPath(aPath string) bool { + if len(aPath) > 0 { + if aPath[0] == '/' { + return true + } + return filepath.IsAbs(aPath) + } + return false +} + +// ImageMimeTypes refer net/http package +var ImageMimeTypes = map[string]string{ + "bmp": "image/bmp", + "gif": "image/gif", + "ief": "image/ief", + "jpg": "image/jpeg", + // "jpe": "image/jpeg", + "jpeg": "image/jpeg", + "png": "image/png", + "svg": "image/svg+xml", + "ico": "image/x-icon", + "webp": "image/webp", +} + +// IsImageFile check file is image file. +func IsImageFile(path string) bool { + mime := MimeType(path) + if mime == "" { + return false + } + + for _, imgMime := range ImageMimeTypes { + if imgMime == mime { + return true + } + } + return false +} + +// IsZipFile check is zip file. +// from https://blog.csdn.net/wangshubo1989/article/details/71743374 +func IsZipFile(filepath string) bool { + f, err := os.Open(filepath) + if err != nil { + return false + } + defer f.Close() + + buf := make([]byte, 4) + if n, err := f.Read(buf); err != nil || n < 4 { + return false + } + + return bytes.Equal(buf, []byte("PK\x03\x04")) +} + +// PathMatch check for a string. alias of path.Match() +func PathMatch(pattern, s string) bool { + ok, err := path.Match(pattern, s) + if err != nil { + ok = false + } + return ok +} diff --git a/nfs/check_test.go b/nfs/check_test.go new file mode 100644 index 0000000..bb28f63 --- /dev/null +++ b/nfs/check_test.go @@ -0,0 +1,87 @@ +package nfs_test + +import ( + "git.noahlan.cn/noahlan/ntool/nfs" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "runtime" + "testing" +) + +//goland:noinspection GoBoolExpressions +func TestNfs_common(t *testing.T) { + assert.Eq(t, "", nfs.FileExt("testdata/testjpg")) + assert.Eq(t, "", nfs.Suffix("testdata/testjpg")) + assert.Eq(t, "", nfs.ExtName("testdata/testjpg")) + assert.Eq(t, ".txt", nfs.FileExt("testdata/test.txt")) + assert.Eq(t, ".txt", nfs.Suffix("testdata/test.txt")) + assert.Eq(t, "txt", nfs.ExtName("testdata/test.txt")) + + // IsZipFile + assert.False(t, nfs.IsZipFile("testdata/not-exists-file")) + assert.False(t, nfs.IsZipFile("testdata/test.txt")) + assert.Eq(t, "test.txt", nfs.PathName("testdata/test.txt")) + + assert.Eq(t, "test.txt", nfs.Name("path/to/test.txt")) + assert.Eq(t, "", nfs.Name("")) + + if runtime.GOOS == "windows" { + assert.Eq(t, "path\\to", nfs.Dir("path/to/test.txt")) + } else { + assert.Eq(t, "path/to", nfs.Dir("path/to/test.txt")) + } +} + +func TestPathExists(t *testing.T) { + assert.False(t, nfs.PathExists("")) + assert.False(t, nfs.PathExists("/not-exist")) + assert.False(t, nfs.PathExists("/not-exist")) + assert.True(t, nfs.PathExists("testdata/test.txt")) + assert.True(t, nfs.PathExists("testdata/test.txt")) +} + +func TestIsFile(t *testing.T) { + assert.False(t, nfs.FileExists("")) + assert.False(t, nfs.IsFile("")) + assert.False(t, nfs.IsFile("/not-exist")) + assert.False(t, nfs.FileExists("/not-exist")) + assert.True(t, nfs.IsFile("testdata/test.txt")) + assert.True(t, nfs.FileExists("testdata/test.txt")) +} + +func TestIsDir(t *testing.T) { + assert.False(t, nfs.IsDir("")) + assert.False(t, nfs.DirExists("")) + assert.False(t, nfs.IsDir("/not-exist")) + assert.True(t, nfs.IsDir("testdata")) + assert.True(t, nfs.DirExists("testdata")) +} + +func TestIsAbsPath(t *testing.T) { + assert.True(t, nfs.IsAbsPath("/data/some.txt")) + assert.False(t, nfs.IsAbsPath("")) + assert.False(t, nfs.IsAbsPath("some.txt")) + assert.NoErr(t, nfs.DeleteIfFileExist("/not-exist")) +} + +func TestGlobMatch(t *testing.T) { + tests := []struct { + p, s string + want bool + }{ + {"a*", "abc", true}, + {"ab.*.ef", "ab.cd.ef", true}, + {"ab.*.*", "ab.cd.ef", true}, + {"ab.cd.*", "ab.cd.ef", true}, + {"ab.*", "ab.cd.ef", true}, + {"a*/b", "a/c/b", false}, + {"a*", "a/c/b", false}, + {"a**", "a/c/b", false}, + } + + for _, tt := range tests { + assert.Eq(t, tt.want, nfs.PathMatch(tt.p, tt.s), "case %v", tt) + } + + assert.False(t, nfs.PathMatch("ab", "abc")) + assert.True(t, nfs.PathMatch("ab*", "abc")) +} diff --git a/nfs/find.go b/nfs/find.go new file mode 100644 index 0000000..7522d92 --- /dev/null +++ b/nfs/find.go @@ -0,0 +1,155 @@ +package nfs + +import ( + "git.noahlan.cn/noahlan/ntool/narr" + "git.noahlan.cn/noahlan/ntool/nstr" + "io/fs" + "os" + "path/filepath" +) + +// SearchNameUp find file/dir name in dirPath or parent dirs, +// return the name of directory path +// +// Usage: +// +// repoDir := nfs.SearchNameUp("/path/to/dir", ".git") +func SearchNameUp(dirPath, name string) string { + dir, _ := SearchNameUpx(dirPath, name) + return dir +} + +// SearchNameUpx find file/dir name in dirPath or parent dirs, +// return the name of directory path and dir is changed. +func SearchNameUpx(dirPath, name string) (string, bool) { + var level int + dirPath = ToAbsPath(dirPath) + + for { + namePath := filepath.Join(dirPath, name) + if PathExists(namePath) { + return dirPath, level > 0 + } + + level++ + prevLn := len(dirPath) + dirPath = filepath.Dir(dirPath) + if prevLn == len(dirPath) { + return "", false + } + } +} + +// WalkDir walks the file tree rooted at root, calling fn for each file or +// directory in the tree, including root. +func WalkDir(dir string, fn fs.WalkDirFunc) error { + return filepath.WalkDir(dir, fn) +} + +// GlobWithFunc handle matched file +// +// - TIP: will be not find in subdir. +func GlobWithFunc(pattern string, fn func(filePath string) error) (err error) { + files, err := filepath.Glob(pattern) + if err != nil { + return err + } + + for _, filePath := range files { + err = fn(filePath) + if err != nil { + break + } + } + return +} + +type ( + // FilterFunc type for FindInDir + // + // - return False will skip handle the file. + FilterFunc func(fPath string, ent fs.DirEntry) bool + + // HandleFunc type for FindInDir + HandleFunc func(fPath string, ent fs.DirEntry) error +) + +// OnlyFindDir on find +func OnlyFindDir(_ string, ent fs.DirEntry) bool { + return ent.IsDir() +} + +// OnlyFindFile on find +func OnlyFindFile(_ string, ent fs.DirEntry) bool { + return !ent.IsDir() +} + +// ExcludeNames on find +func ExcludeNames(names ...string) FilterFunc { + return func(_ string, ent fs.DirEntry) bool { + return !narr.StringsHas(names, ent.Name()) + } +} + +// IncludeSuffix on find +func IncludeSuffix(ss ...string) FilterFunc { + return func(_ string, ent fs.DirEntry) bool { + return nstr.HasOneSuffix(ent.Name(), ss) + } +} + +// ExcludeDotFile on find +func ExcludeDotFile(_ string, ent fs.DirEntry) bool { + return ent.Name()[0] != '.' +} + +// ExcludeSuffix on find +func ExcludeSuffix(ss ...string) FilterFunc { + return func(_ string, ent fs.DirEntry) bool { + return !nstr.HasOneSuffix(ent.Name(), ss) + } +} + +// ApplyFilters handle +func ApplyFilters(fPath string, ent fs.DirEntry, filters []FilterFunc) bool { + for _, filter := range filters { + if !filter(fPath, ent) { + return true + } + } + return false +} + +// FindInDir code refer the go pkg: path/filepath.glob() +// +// - TIP: will be not find in subdir. +// +// filters: return false will skip the file. +func FindInDir(dir string, handleFn HandleFunc, filters ...FilterFunc) (e error) { + fi, err := os.Stat(dir) + if err != nil || !fi.IsDir() { + return // ignore I/O error + } + + // names, _ := d.Readdirnames(-1) + // sort.Strings(names) + + des, err := os.ReadDir(dir) + if err != nil { + return + } + + for _, ent := range des { + filePath := dir + "/" + ent.Name() + + // apply filters + if len(filters) > 0 && ApplyFilters(filePath, ent, filters) { + continue + } + + if err := handleFn(filePath, ent); err != nil { + return err + } + } + return nil +} diff --git a/nfs/find_test.go b/nfs/find_test.go new file mode 100644 index 0000000..1136000 --- /dev/null +++ b/nfs/find_test.go @@ -0,0 +1,95 @@ +package nfs_test + +import ( + "fmt" + "git.noahlan.cn/noahlan/ntool/nfs" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "io/fs" + "strings" + "testing" +) + +func TestSearchNameUp(t *testing.T) { + p := nfs.SearchNameUp("testdata", "finder") + assert.NotEmpty(t, p) + assert.True(t, strings.HasSuffix(p, "nfs")) + + p = nfs.SearchNameUp("testdata", ".dotdir") + assert.NotEmpty(t, p) + assert.True(t, strings.HasSuffix(p, "testdata")) + + p = nfs.SearchNameUp("testdata", "test.txt") + assert.NotEmpty(t, p) + assert.True(t, strings.HasSuffix(p, "testdata")) + + p = nfs.SearchNameUp("testdata", "not-exists") + assert.Empty(t, p) +} + +type dirEnt struct { + typ fs.FileMode + isDir bool + name string +} + +func (d *dirEnt) Name() string { + return d.name +} + +func (d *dirEnt) IsDir() bool { + return d.isDir +} + +func (d *dirEnt) Type() fs.FileMode { + return d.typ +} + +func (d *dirEnt) Info() (fs.FileInfo, error) { + panic("implement me") +} + +func TestApplyFilters(t *testing.T) { + e1 := &dirEnt{name: "some-backup"} + f1 := nfs.ExcludeSuffix("-backup") + + assert.False(t, f1("", e1)) + assert.True(t, nfs.ApplyFilters("", e1, []nfs.FilterFunc{f1})) + assert.True(t, nfs.ApplyFilters("", e1, []nfs.FilterFunc{nfs.OnlyFindDir})) + assert.False(t, nfs.ApplyFilters("", e1, []nfs.FilterFunc{nfs.OnlyFindFile})) + assert.False(t, nfs.ApplyFilters("", e1, []nfs.FilterFunc{nfs.ExcludeDotFile})) + assert.False(t, nfs.ApplyFilters("", e1, []nfs.FilterFunc{nfs.IncludeSuffix("-backup")})) + assert.True(t, nfs.ApplyFilters("", e1, []nfs.FilterFunc{nfs.ExcludeNames("some-backup")})) +} + +func TestFindInDir(t *testing.T) { + err := nfs.FindInDir("path-not-exist", nil) + assert.NoErr(t, err) + + err = nfs.FindInDir("testdata/test.txt", nil) + assert.NoErr(t, err) + + files := make([]string, 0, 8) + err = nfs.FindInDir("testdata", func(fPath string, de fs.DirEntry) error { + files = append(files, fPath) + return nil + }) + + //dump.P(files) + assert.NoErr(t, err) + assert.True(t, len(files) > 0) + + files = files[:0] + err = nfs.FindInDir("testdata", func(fPath string, de fs.DirEntry) error { + files = append(files, fPath) + return nil + }, func(fPath string, de fs.DirEntry) bool { + return !strings.HasPrefix(de.Name(), ".") + }) + assert.NoErr(t, err) + assert.True(t, len(files) > 0) + + err = nfs.FindInDir("testdata", func(fPath string, de fs.DirEntry) error { + return fmt.Errorf("handle error") + }) + assert.Err(t, err) +} diff --git a/nfs/finder/README.md b/nfs/finder/README.md new file mode 100644 index 0000000..257b797 --- /dev/null +++ b/nfs/finder/README.md @@ -0,0 +1,23 @@ +# finder + +`finder` provide a finding tool for find files or dirs, and with some built-in matchers. + +## Usage + +```go +package main + +import ( + "git.noahlan.cn/noahlan/ntool/nfs/finder" +) + +func main() { + ff := finder.NewFinder() + ff.AddScan("/tmp", "/usr/local", "/usr/local/share") + ff.ExcludeDir("abc", "def").ExcludeFile("*.log", "*.tmp") + + //ss := ff.FindPaths() + //dump.P(ss) +} +``` + diff --git a/nfs/finder/config.go b/nfs/finder/config.go new file mode 100644 index 0000000..1d81d3f --- /dev/null +++ b/nfs/finder/config.go @@ -0,0 +1,492 @@ +package finder + +import "strings" + +// commonly dot file and dirs +var ( + CommonlyDotDirs = []string{".git", ".idea", ".vscode", ".svn", ".hg"} + CommonlyDotFiles = []string{".gitignore", ".dockerignore", ".npmignore", ".DS_Store", ".env"} +) + +// FindFlag type for find result. +type FindFlag uint8 + +// flags for find result. +const ( + FlagFile FindFlag = iota + 1 // only find files(default) + FlagDir + FlagBoth = FlagFile | FlagDir +) + +// ToFlag convert string to FindFlag +func ToFlag(s string) FindFlag { + switch strings.ToLower(s) { + case "dirs", "dir", "d": + return FlagDir + case "both", "b": + return FlagBoth + default: + return FlagFile + } +} + +// Config for finder +type Config struct { + init bool + depth int + + // ScanDirs scan dir paths for find. + ScanDirs []string `json:"scan_dirs"` + // FindFlags for find result. default is FlagFile + FindFlags FindFlag `json:"find_flags"` + // MaxDepth for find result. default is 0 - not limit + MaxDepth int `json:"max_depth"` + // UseAbsPath use abs path for find result. default is false + UseAbsPath bool `json:"use_abs_path"` + // CacheResult cache result for find result. default is false + CacheResult bool `json:"cache_result"` + // ExcludeDotDir exclude dot dir. default is true + ExcludeDotDir bool `json:"exclude_dot_dir"` + // ExcludeDotFile exclude dot dir. default is false + ExcludeDotFile bool `json:"exclude_dot_file"` + + // Matchers generic include matchers for file/dir elems + Matchers []Matcher + // ExMatchers generic exclude matchers for file/dir elems + ExMatchers []Matcher + // DirMatchers include matchers for dir elems + DirMatchers []Matcher + // DirExMatchers exclude matchers for dir elems + DirExMatchers []Matcher + // FileMatchers include matchers for file elems + FileMatchers []Matcher + // FileExMatchers exclude matchers for file elems + FileExMatchers []Matcher + + // commonly settings for build matchers + + // IncludeDirs include dir name list. eg: {"model"} + IncludeDirs []string `json:"include_dirs"` + // IncludeExts include file ext name list. eg: {".go", ".md"} + IncludeExts []string `json:"include_exts"` + // IncludeFiles include file name list. eg: {"go.mod"} + IncludeFiles []string `json:"include_files"` + // IncludePaths include file/dir path list. eg: {"path/to"} + IncludePaths []string `json:"include_paths"` + // IncludeNames include file/dir name list. eg: {"test", "some.go"} + IncludeNames []string `json:"include_names"` + + // ExcludeDirs exclude dir name list. eg: {"test"} + ExcludeDirs []string `json:"exclude_dirs"` + // ExcludeExts exclude file ext name list. eg: {".go", ".md"} + ExcludeExts []string `json:"exclude_exts"` + // ExcludeFiles exclude file name list. eg: {"go.mod"} + ExcludeFiles []string `json:"exclude_files"` + // ExcludePaths exclude file/dir path list. eg: {"path/to"} + ExcludePaths []string `json:"exclude_paths"` + // ExcludeNames exclude file/dir name list. eg: {"test", "some.go"} + ExcludeNames []string `json:"exclude_names"` +} + +// NewConfig create a new Config +func NewConfig(dirs ...string) *Config { + return &Config{ + ScanDirs: dirs, + FindFlags: FlagFile, + // with default setting. + ExcludeDotDir: true, + } +} + +// NewEmptyConfig create a new Config +func NewEmptyConfig() *Config { + return &Config{FindFlags: FlagFile} +} + +// NewFinder create a new Finder by config +func (c *Config) NewFinder() *Finder { + return NewWithConfig(c.Init()) +} + +// Init build matchers by config and append to Matchers. +func (c *Config) Init() *Config { + if c.init { + return c + } + + // generic matchers + if len(c.IncludeNames) > 0 { + c.Matchers = append(c.Matchers, MatchNames(c.IncludeNames)) + } + + if len(c.IncludePaths) > 0 { + c.Matchers = append(c.Matchers, MatchPaths(c.IncludePaths)) + } + + if len(c.ExcludePaths) > 0 { + c.ExMatchers = append(c.ExMatchers, MatchPaths(c.ExcludePaths)) + } + + if len(c.ExcludeNames) > 0 { + c.ExMatchers = append(c.ExMatchers, MatchNames(c.ExcludeNames)) + } + + // dir matchers + if len(c.IncludeDirs) > 0 { + c.DirMatchers = append(c.DirMatchers, MatchNames(c.IncludeDirs)) + } + + if len(c.ExcludeDirs) > 0 { + c.DirExMatchers = append(c.DirExMatchers, MatchNames(c.ExcludeDirs)) + } + + // file matchers + if len(c.IncludeExts) > 0 { + c.FileMatchers = append(c.FileMatchers, MatchExts(c.IncludeExts)) + } + + if len(c.IncludeFiles) > 0 { + c.FileMatchers = append(c.FileMatchers, MatchNames(c.IncludeFiles)) + } + + if len(c.ExcludeExts) > 0 { + c.FileExMatchers = append(c.FileExMatchers, MatchExts(c.ExcludeExts)) + } + + if len(c.ExcludeFiles) > 0 { + c.FileExMatchers = append(c.FileExMatchers, MatchNames(c.ExcludeFiles)) + } + + return c +} + +// +// --------- config for finder --------- +// + +// WithConfig on the finder +func (f *Finder) WithConfig(c *Config) *Finder { + f.c = c + return f +} + +// ConfigFn the finder. alias of WithConfigFn() +func (f *Finder) ConfigFn(fns ...func(c *Config)) *Finder { return f.WithConfigFn(fns...) } + +// WithConfigFn the finder +func (f *Finder) WithConfigFn(fns ...func(c *Config)) *Finder { + if f.c == nil { + f.c = &Config{} + } + + for _, fn := range fns { + fn(f.c) + } + return f +} + +// AddScanDirs add source dir for find +func (f *Finder) AddScanDirs(dirPaths []string) *Finder { + f.c.ScanDirs = append(f.c.ScanDirs, dirPaths...) + return f +} + +// AddScanDir add source dir for find. alias of AddScanDirs() +func (f *Finder) AddScanDir(dirPaths ...string) *Finder { return f.AddScanDirs(dirPaths) } + +// AddScan add source dir for find. alias of AddScanDirs() +func (f *Finder) AddScan(dirPaths ...string) *Finder { return f.AddScanDirs(dirPaths) } + +// ScanDir add source dir for find. alias of AddScanDirs() +func (f *Finder) ScanDir(dirPaths ...string) *Finder { return f.AddScanDirs(dirPaths) } + +// CacheResult cache result for find result. +func (f *Finder) CacheResult(enable ...bool) *Finder { + if len(enable) > 0 { + f.c.CacheResult = enable[0] + } else { + f.c.CacheResult = true + } + return f +} + +// WithFlags set find flags. +func (f *Finder) WithFlags(flags FindFlag) *Finder { + f.c.FindFlags = flags + return f +} + +// WithStrFlag set find flags by string. +func (f *Finder) WithStrFlag(s string) *Finder { + f.c.FindFlags = ToFlag(s) + return f +} + +// OnlyFindDir only find dir. +func (f *Finder) OnlyFindDir() *Finder { return f.WithFlags(FlagDir) } + +// FileAndDir both find file and dir. +func (f *Finder) FileAndDir() *Finder { return f.WithFlags(FlagDir | FlagFile) } + +// UseAbsPath use absolute path for find result. alias of WithUseAbsPath() +func (f *Finder) UseAbsPath(enable ...bool) *Finder { return f.WithUseAbsPath(enable...) } + +// WithUseAbsPath use absolute path for find result. +func (f *Finder) WithUseAbsPath(enable ...bool) *Finder { + if len(enable) > 0 { + f.c.UseAbsPath = enable[0] + } else { + f.c.UseAbsPath = true + } + return f +} + +// WithMaxDepth set max depth for find. +func (f *Finder) WithMaxDepth(i int) *Finder { + f.c.MaxDepth = i + return f +} + +// IncludeDir include dir names. +func (f *Finder) IncludeDir(dirs ...string) *Finder { + f.c.IncludeDirs = append(f.c.IncludeDirs, dirs...) + return f +} + +// WithDirName include dir names. alias of IncludeDir() +func (f *Finder) WithDirName(dirs ...string) *Finder { return f.IncludeDir(dirs...) } + +// IncludeFile include file names. +func (f *Finder) IncludeFile(files ...string) *Finder { + f.c.IncludeFiles = append(f.c.IncludeFiles, files...) + return f +} + +// WithFileName include file names. alias of IncludeFile() +func (f *Finder) WithFileName(files ...string) *Finder { return f.IncludeFile(files...) } + +// IncludeName include file or dir names. +func (f *Finder) IncludeName(names ...string) *Finder { + f.c.IncludeNames = append(f.c.IncludeNames, names...) + return f +} + +// WithNames include file or dir names. alias of IncludeName() +func (f *Finder) WithNames(names []string) *Finder { return f.IncludeName(names...) } + +// IncludeExt include file exts. +func (f *Finder) IncludeExt(exts ...string) *Finder { + f.c.IncludeExts = append(f.c.IncludeExts, exts...) + return f +} + +// WithExts include file exts. alias of IncludeExt() +func (f *Finder) WithExts(exts []string) *Finder { return f.IncludeExt(exts...) } + +// WithFileExt include file exts. alias of IncludeExt() +func (f *Finder) WithFileExt(exts ...string) *Finder { return f.IncludeExt(exts...) } + +// IncludePath include file or dir paths. +func (f *Finder) IncludePath(paths ...string) *Finder { + f.c.IncludePaths = append(f.c.IncludePaths, paths...) + return f +} + +// WithPaths include file or dir paths. alias of IncludePath() +func (f *Finder) WithPaths(paths []string) *Finder { return f.IncludePath(paths...) } + +// WithSubPath include file or dir paths. alias of IncludePath() +func (f *Finder) WithSubPath(paths ...string) *Finder { return f.IncludePath(paths...) } + +// ExcludeDir exclude dir names. +func (f *Finder) ExcludeDir(dirs ...string) *Finder { + f.c.ExcludeDirs = append(f.c.ExcludeDirs, dirs...) + return f +} + +// WithoutDir exclude dir names. alias of ExcludeDir() +func (f *Finder) WithoutDir(dirs ...string) *Finder { return f.ExcludeDir(dirs...) } + +// WithoutNames exclude file or dir names. +func (f *Finder) WithoutNames(names []string) *Finder { + f.c.ExcludeNames = append(f.c.ExcludeNames, names...) + return f +} + +// ExcludeName exclude file names. alias of WithoutNames() +func (f *Finder) ExcludeName(names ...string) *Finder { return f.WithoutNames(names) } + +// ExcludeFile exclude file names. +func (f *Finder) ExcludeFile(files ...string) *Finder { + f.c.ExcludeFiles = append(f.c.ExcludeFiles, files...) + return f +} + +// WithoutFile exclude file names. alias of ExcludeFile() +func (f *Finder) WithoutFile(files ...string) *Finder { return f.ExcludeFile(files...) } + +// ExcludeExt exclude file exts. +// +// eg: ExcludeExt(".go", ".java") +func (f *Finder) ExcludeExt(exts ...string) *Finder { + f.c.ExcludeExts = append(f.c.ExcludeExts, exts...) + return f +} + +// WithoutExt exclude file exts. alias of ExcludeExt() +func (f *Finder) WithoutExt(exts ...string) *Finder { return f.ExcludeExt(exts...) } + +// WithoutExts exclude file exts. alias of ExcludeExt() +func (f *Finder) WithoutExts(exts []string) *Finder { return f.ExcludeExt(exts...) } + +// ExcludePath exclude file paths. +func (f *Finder) ExcludePath(paths ...string) *Finder { + f.c.ExcludePaths = append(f.c.ExcludePaths, paths...) + return f +} + +// WithoutPath exclude file paths. alias of ExcludePath() +func (f *Finder) WithoutPath(paths ...string) *Finder { return f.ExcludePath(paths...) } + +// WithoutPaths exclude file paths. alias of ExcludePath() +func (f *Finder) WithoutPaths(paths []string) *Finder { return f.ExcludePath(paths...) } + +// ExcludeDotDir exclude dot dir names. eg: ".idea" +func (f *Finder) ExcludeDotDir(exclude ...bool) *Finder { + if len(exclude) > 0 { + f.c.ExcludeDotDir = exclude[0] + } else { + f.c.ExcludeDotDir = true + } + return f +} + +// WithoutDotDir exclude dot dir names. alias of ExcludeDotDir(). +func (f *Finder) WithoutDotDir(exclude ...bool) *Finder { + return f.ExcludeDotDir(exclude...) +} + +// NoDotDir exclude dot dir names. alias of ExcludeDotDir(). +func (f *Finder) NoDotDir(exclude ...bool) *Finder { + return f.ExcludeDotDir(exclude...) +} + +// ExcludeDotFile exclude dot dir names. eg: ".gitignore" +func (f *Finder) ExcludeDotFile(exclude ...bool) *Finder { + if len(exclude) > 0 { + f.c.ExcludeDotFile = exclude[0] + } else { + f.c.ExcludeDotFile = true + } + return f +} + +// WithoutDotFile exclude dot dir names. alias of ExcludeDotFile(). +func (f *Finder) WithoutDotFile(exclude ...bool) *Finder { + return f.ExcludeDotFile(exclude...) +} + +// NoDotFile exclude dot dir names. alias of ExcludeDotFile(). +func (f *Finder) NoDotFile(exclude ...bool) *Finder { + return f.ExcludeDotFile(exclude...) +} + +// +// --------- add matchers to finder --------- +// + +// Includes add include match matchers +func (f *Finder) Includes(fls []Matcher) *Finder { + f.c.Matchers = append(f.c.Matchers, fls...) + return f +} + +// Collect add include match matchers. alias of Includes() +func (f *Finder) Collect(fls ...Matcher) *Finder { return f.Includes(fls) } + +// Include add include match matchers. alias of Includes() +func (f *Finder) Include(fls ...Matcher) *Finder { return f.Includes(fls) } + +// With add include match matchers. alias of Includes() +func (f *Finder) With(fls ...Matcher) *Finder { return f.Includes(fls) } + +// Adds include match matchers. alias of Includes() +func (f *Finder) Adds(fls []Matcher) *Finder { return f.Includes(fls) } + +// Add include match matchers. alias of Includes() +func (f *Finder) Add(fls ...Matcher) *Finder { return f.Includes(fls) } + +// Excludes add exclude match matchers +func (f *Finder) Excludes(fls []Matcher) *Finder { + f.c.ExMatchers = append(f.c.ExMatchers, fls...) + return f +} + +// Exclude add exclude match matchers. alias of Excludes() +func (f *Finder) Exclude(fls ...Matcher) *Finder { return f.Excludes(fls) } + +// Without add exclude match matchers. alias of Excludes() +func (f *Finder) Without(fls ...Matcher) *Finder { return f.Excludes(fls) } + +// Nots add exclude match matchers. alias of Excludes() +func (f *Finder) Nots(fls []Matcher) *Finder { return f.Excludes(fls) } + +// Not add exclude match matchers. alias of Excludes() +func (f *Finder) Not(fls ...Matcher) *Finder { return f.Excludes(fls) } + +// WithMatchers add include matchers +func (f *Finder) WithMatchers(fls []Matcher) *Finder { + f.c.Matchers = append(f.c.Matchers, fls...) + return f +} + +// WithFilter add include matchers +func (f *Finder) WithFilter(fls ...Matcher) *Finder { return f.WithMatchers(fls) } + +// MatchFiles add include file matchers +func (f *Finder) MatchFiles(fls []Matcher) *Finder { + f.c.FileMatchers = append(f.c.FileMatchers, fls...) + return f +} + +// MatchFile add include file matchers +func (f *Finder) MatchFile(fls ...Matcher) *Finder { return f.MatchFiles(fls) } + +// AddFiles add include file matchers +func (f *Finder) AddFiles(fls []Matcher) *Finder { return f.MatchFiles(fls) } + +// AddFile add include file matchers +func (f *Finder) AddFile(fls ...Matcher) *Finder { return f.MatchFiles(fls) } + +// NotFiles add exclude file matchers +func (f *Finder) NotFiles(fls []Matcher) *Finder { + f.c.FileExMatchers = append(f.c.FileExMatchers, fls...) + return f +} + +// NotFile add exclude file matchers +func (f *Finder) NotFile(fls ...Matcher) *Finder { return f.NotFiles(fls) } + +// MatchDirs add exclude dir matchers +func (f *Finder) MatchDirs(fls []Matcher) *Finder { + f.c.DirMatchers = append(f.c.DirMatchers, fls...) + return f +} + +// MatchDir add exclude dir matchers +func (f *Finder) MatchDir(fls ...Matcher) *Finder { return f.MatchDirs(fls) } + +// WithDirs add exclude dir matchers +func (f *Finder) WithDirs(fls []Matcher) *Finder { return f.MatchDirs(fls) } + +// WithDir add exclude dir matchers +func (f *Finder) WithDir(fls ...Matcher) *Finder { return f.MatchDirs(fls) } + +// NotDirs add exclude dir matchers +func (f *Finder) NotDirs(fls []Matcher) *Finder { + f.c.DirExMatchers = append(f.c.DirExMatchers, fls...) + return f +} + +// NotDir add exclude dir matchers +func (f *Finder) NotDir(fls ...Matcher) *Finder { return f.NotDirs(fls) } diff --git a/nfs/finder/elem.go b/nfs/finder/elem.go new file mode 100644 index 0000000..d0eec07 --- /dev/null +++ b/nfs/finder/elem.go @@ -0,0 +1,48 @@ +package finder + +import ( + "git.noahlan.cn/noahlan/ntool/nstr" + "io/fs" +) + +// Elem of find file/dir path result +type Elem interface { + fs.DirEntry + // Path get file/dir path. eg: "/path/to/file.go" + Path() string + // Info get file info. like fs.DirEntry.Info(), but will cache result. + Info() (fs.FileInfo, error) +} + +type elem struct { + fs.DirEntry + path string + stat fs.FileInfo + sErr error +} + +// NewElem create a new Elem instance +func NewElem(fPath string, ent fs.DirEntry) Elem { + return &elem{ + path: fPath, + DirEntry: ent, + } +} + +// Path get full file/dir path. eg: "/path/to/file.go" +func (e *elem) Path() string { + return e.path +} + +// Info get file info, will cache result +func (e *elem) Info() (fs.FileInfo, error) { + if e.stat == nil { + e.stat, e.sErr = e.DirEntry.Info() + } + return e.stat, e.sErr +} + +// String get string representation +func (e *elem) String() string { + return nstr.OrCond(e.IsDir(), "dir: ", "file: ") + e.Path() +} diff --git a/nfs/finder/finder.go b/nfs/finder/finder.go new file mode 100644 index 0000000..29f4e35 --- /dev/null +++ b/nfs/finder/finder.go @@ -0,0 +1,353 @@ +// Package finder provide a finding tool for find files or dirs, +// and with some built-in matchers. +package finder + +import ( + "os" + "path/filepath" + "strings" +) + +// FileFinder type alias. +type FileFinder = Finder + +// Finder struct +type Finder struct { + // config for finder + c *Config + // last error + err error + // num - founded fs elem number + num int + // ch - founded fs elem chan + ch chan Elem + // caches - cache found fs elem. if config.CacheResult is true + caches []Elem +} + +// New instance with source dir paths. +func New(dirs []string) *Finder { + c := NewConfig(dirs...) + return NewWithConfig(c) +} + +// NewFinder new instance with source dir paths. +func NewFinder(dirPaths ...string) *Finder { return New(dirPaths) } + +// NewWithConfig new instance with config. +func NewWithConfig(c *Config) *Finder { + return &Finder{c: c} +} + +// NewEmpty new empty Finder instance +func NewEmpty() *Finder { + return &Finder{c: NewEmptyConfig()} +} + +// EmptyFinder new empty Finder instance. alias of NewEmpty() +func EmptyFinder() *Finder { return NewEmpty() } + +// +// --------- do finding --------- +// + +// Find files in given dir paths. will return a channel, you can use it to get the result. +// +// Usage: +// +// f := NewFinder("/path/to/dir").Find() +// for el := range f { +// fmt.Println(el.Path()) +// } +func (f *Finder) Find() <-chan Elem { + f.find() + return f.ch +} + +// Elems find and return founded file Elem. alias of Find() +func (f *Finder) Elems() <-chan Elem { return f.Find() } + +// Results find and return founded file Elem. alias of Find() +func (f *Finder) Results() <-chan Elem { return f.Find() } + +// FindNames find and return founded file/dir names. +func (f *Finder) FindNames() []string { + paths := make([]string, 0, 8*len(f.c.ScanDirs)) + for el := range f.Find() { + paths = append(paths, el.Name()) + } + return paths +} + +// FindPaths find and return founded file/dir paths. +func (f *Finder) FindPaths() []string { + paths := make([]string, 0, 8*len(f.c.ScanDirs)) + for el := range f.Find() { + paths = append(paths, el.Path()) + } + return paths +} + +// Each founded file or dir Elem. +func (f *Finder) Each(fn func(el Elem)) { f.EachElem(fn) } + +// EachElem founded file or dir Elem. +func (f *Finder) EachElem(fn func(el Elem)) { + f.find() + for el := range f.ch { + fn(el) + } +} + +// EachPath founded file paths. +func (f *Finder) EachPath(fn func(filePath string)) { + f.EachElem(func(el Elem) { + fn(el.Path()) + }) +} + +// EachFile each file os.File +func (f *Finder) EachFile(fn func(file *os.File)) { + f.EachElem(func(el Elem) { + file, err := os.Open(el.Path()) + if err == nil { + fn(file) + } else { + f.err = err + } + }) +} + +// EachStat each file os.FileInfo +func (f *Finder) EachStat(fn func(fi os.FileInfo, filePath string)) { + f.EachElem(func(el Elem) { + fi, err := el.Info() + if err == nil { + fn(fi, el.Path()) + } else { + f.err = err + } + }) +} + +// EachContents handle each found file contents +func (f *Finder) EachContents(fn func(contents, filePath string)) { + f.EachElem(func(el Elem) { + bs, err := os.ReadFile(el.Path()) + if err == nil { + fn(string(bs), el.Path()) + } else { + f.err = err + } + }) +} + +// prepare for find. +func (f *Finder) prepare() { + f.err = nil + f.ch = make(chan Elem, 8) + + if f.CacheNum() == 0 { + f.num = 0 + } + + if f.c == nil { + f.c = NewConfig() + } else { + f.c.Init() + } +} + +// do finding +func (f *Finder) find() { + f.prepare() + + go func() { + defer close(f.ch) + + // read from caches + if f.c.CacheResult && len(f.caches) > 0 { + for _, el := range f.caches { + f.ch <- el + } + return + } + + // do finding + var err error + for _, dirPath := range f.c.ScanDirs { + if f.c.UseAbsPath { + dirPath, err = filepath.Abs(dirPath) + if err != nil { + f.err = err + continue + } + } + + f.c.depth = 0 + f.findDir(dirPath, f.c) + } + }() +} + +// code refer filepath.glob() +func (f *Finder) findDir(dirPath string, c *Config) { + des, err := os.ReadDir(dirPath) + if err != nil { + return // ignore I/O error + } + + var ok bool + c.depth++ + for _, ent := range des { + name := ent.Name() + isDir := ent.IsDir() + if name[0] == '.' { + if isDir { + if c.ExcludeDotDir { + continue + } + } else if c.ExcludeDotFile { + continue + } + } + + fullPath := filepath.Join(dirPath, name) + el := NewElem(fullPath, ent) + + // apply generic filters + if !applyExMatchers(el, c.ExMatchers) { + continue + } + + // --- dir: apply dir filters + if isDir { + if !applyExMatchers(el, c.DirExMatchers) { + continue + } + + if len(c.Matchers) > 0 { + ok = applyMatchers(el, c.Matchers) + if !ok && len(c.DirMatchers) > 0 { + ok = applyMatchers(el, c.DirMatchers) + } + } else { + ok = applyMatchers(el, c.DirMatchers) + } + + if ok && c.FindFlags&FlagDir > 0 { + if c.CacheResult { + f.caches = append(f.caches, el) + } + f.num++ + f.ch <- el + + if c.FindFlags == FlagDir { + continue // only find subdir on ok=false + } + } + + // find in sub dir. + if c.MaxDepth == 0 || c.depth < c.MaxDepth { + f.findDir(fullPath, c) + c.depth-- // restore depth + } + continue + } + + // --- type: file + if c.FindFlags&FlagFile == 0 { + continue + } + + // apply file filters + if !applyExMatchers(el, c.FileExMatchers) { + continue + } + + if len(c.Matchers) > 0 { + ok = applyMatchers(el, c.Matchers) + if !ok && len(c.FileMatchers) > 0 { + ok = applyMatchers(el, c.FileMatchers) + } + } else { + ok = applyMatchers(el, c.FileMatchers) + } + + // write to consumer + if ok && c.FindFlags&FlagFile > 0 { + if c.CacheResult { + f.caches = append(f.caches, el) + } + f.num++ + f.ch <- el + } + } +} + +func applyMatchers(el Elem, fls []Matcher) bool { + for _, f := range fls { + if f.Apply(el) { + return true + } + } + return len(fls) == 0 +} + +func applyExMatchers(el Elem, fls []Matcher) bool { + for _, f := range fls { + if f.Apply(el) { + return false + } + } + return true +} + +// Reset filters config setting and results info. +func (f *Finder) Reset() { + c := NewConfig(f.c.ScanDirs...) + c.ExcludeDotDir = f.c.ExcludeDotDir + c.FindFlags = f.c.FindFlags + c.MaxDepth = f.c.MaxDepth + + f.c = c + f.ResetResult() +} + +// ResetResult reset result info. +func (f *Finder) ResetResult() { + f.num = 0 + f.err = nil + f.ch = make(chan Elem, 8) + f.caches = []Elem{} +} + +// Num get found elem num. only valid after finding. +func (f *Finder) Num() int { + return f.num +} + +// Err get last error +func (f *Finder) Err() error { + return f.err +} + +// Caches get cached results. only valid after finding. +func (f *Finder) Caches() []Elem { + return f.caches +} + +// CacheNum get +func (f *Finder) CacheNum() int { + return len(f.caches) +} + +// Config get +func (f *Finder) Config() Config { + return *f.c +} + +// String all dir paths +func (f *Finder) String() string { + return strings.Join(f.c.ScanDirs, ";") +} diff --git a/nfs/finder/finder_test.go b/nfs/finder/finder_test.go new file mode 100644 index 0000000..786bdfd --- /dev/null +++ b/nfs/finder/finder_test.go @@ -0,0 +1,160 @@ +package finder_test + +import ( + "fmt" + "git.noahlan.cn/noahlan/ntool/nfs" + "git.noahlan.cn/noahlan/ntool/nfs/finder" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "os" + "testing" +) + +func TestMain(m *testing.M) { + _, _ = nfs.PutContents("./testdata/test.txt", "hello, in test.txt") + m.Run() +} + +func TestFinder_findFile(t *testing.T) { + f := finder.EmptyFinder(). + ScanDir("./testdata"). + NoDotFile(). + NoDotDir(). + WithoutExt(".jpg"). + CacheResult() + + assert.Nil(t, f.Err()) + assert.NotEmpty(t, f.String()) + assert.Eq(t, 0, f.CacheNum()) + + // find paths + assert.NotEmpty(t, f.FindPaths()) + assert.Gt(t, f.CacheNum(), 0) + assert.NotEmpty(t, f.Caches()) + + f.Each(func(elem finder.Elem) { + fmt.Println(elem) + }) + + t.Run("each elem", func(t *testing.T) { + f.EachElem(func(elem finder.Elem) { + fmt.Println(elem) + }) + }) + + t.Run("each file", func(t *testing.T) { + f.EachFile(func(file *os.File) { + fmt.Println(file.Name()) + }) + }) + + t.Run("each path", func(t *testing.T) { + f.EachPath(func(filePath string) { + fmt.Println(filePath) + }) + }) + + t.Run("each stat", func(t *testing.T) { + f.EachStat(func(fi os.FileInfo, filePath string) { + fmt.Println(filePath, "=>", fi.ModTime()) + }) + }) + + t.Run("reset", func(t *testing.T) { + f.Reset() + assert.Empty(t, f.Caches()) + assert.NotEmpty(t, f.FindPaths()) + + f.EachElem(func(elem finder.Elem) { + fmt.Println(elem) + }) + }) +} + +func TestFinder_OnlyFindDir(t *testing.T) { + ff := finder.NewFinder("./../../"). + OnlyFindDir(). + UseAbsPath(). + WithoutDotDir(). + WithDirName("testdata") + + ff.EachPath(func(filePath string) { + fmt.Println(filePath) + }) + assert.Gt(t, ff.Num(), 0) + assert.Eq(t, 0, ff.CacheNum()) + + t.Run("each elem", func(t *testing.T) { + ff.Each(func(elem finder.Elem) { + fmt.Println(elem) + }) + }) + + ff.ResetResult() + assert.Eq(t, 0, ff.Num()) + assert.Eq(t, 0, ff.CacheNum()) + + t.Run("max depth", func(t *testing.T) { + ff.WithMaxDepth(2) + ff.EachPath(func(filePath string) { + fmt.Println(filePath) + }) + assert.Gt(t, ff.Num(), 0) + }) +} + +func TestFileFinder_NoDotFile(t *testing.T) { + f := finder.NewEmpty(). + CacheResult(). + ScanDir("./testdata") + assert.NotEmpty(t, f.String()) + + fileName := ".env" + assert.NotEmpty(t, f.FindPaths()) + assert.Contains(t, f.FindNames(), fileName) + + f = finder.EmptyFinder(). + ScanDir("./testdata"). + NoDotFile() + assert.NotContains(t, f.FindNames(), fileName) + + t.Run("Not MatchDotFile", func(t *testing.T) { + f = finder.EmptyFinder(). + ScanDir("./testdata"). + Not(finder.MatchDotFile()) + + assert.NotContains(t, f.FindNames(), fileName) + }) +} + +func TestFileFinder_IncludeName(t *testing.T) { + f := finder.NewFinder("."). + IncludeName("elem.go"). + WithNames([]string{"not-exist.file"}) + + names := f.FindNames() + assert.Len(t, names, 1) + assert.Contains(t, names, "elem.go") + assert.NotContains(t, names, "not-exist.file") + + f.Reset() + t.Run("name in subdir", func(t *testing.T) { + f.WithFileName("test.txt") + names = f.FindNames() + assert.Len(t, names, 1) + assert.Contains(t, names, "test.txt") + }) +} + +func TestFileFinder_ExcludeName(t *testing.T) { + f := finder.NewEmpty(). + AddScanDir("."). + WithMaxDepth(1). + ExcludeName("elem.go"). + WithoutNames([]string{"config.go"}) + f.Exclude(finder.MatchSuffix("_test.go"), finder.MatchExt(".md")) + + names := f.FindNames() + fmt.Println(names) + assert.Contains(t, names, "matcher.go") + assert.NotContains(t, names, "elem.go") +} diff --git a/nfs/finder/matcher.go b/nfs/finder/matcher.go new file mode 100644 index 0000000..758316b --- /dev/null +++ b/nfs/finder/matcher.go @@ -0,0 +1,138 @@ +package finder + +import ( + "bytes" + "git.noahlan.cn/noahlan/ntool/nfs" +) + +// Matcher for match file path. +type Matcher interface { + // Apply check find elem. return False will skip this file. + Apply(elem Elem) bool +} + +// MatcherFunc for match file info, return False will skip this file +type MatcherFunc func(elem Elem) bool + +// Apply check file path. return False will skip this file. +func (fn MatcherFunc) Apply(elem Elem) bool { + return fn(elem) +} + +// ------------------ Multi matcher wrapper ------------------ + +// MultiFilter wrapper for multi matchers +type MultiFilter struct { + Before Matcher + Filters []Matcher +} + +// Add matchers +func (mf *MultiFilter) Add(fls ...Matcher) { + mf.Filters = append(mf.Filters, fls...) +} + +// Apply check file path. return False will filter this file. +func (mf *MultiFilter) Apply(el Elem) bool { + if mf.Before != nil && !mf.Before.Apply(el) { + return false + } + + for _, fl := range mf.Filters { + if !fl.Apply(el) { + return false + } + } + return true +} + +// NewDirFilters create a new dir matchers +func NewDirFilters(fls ...Matcher) *MultiFilter { + return &MultiFilter{ + Before: MatchDir, + Filters: fls, + } +} + +// NewFileFilters create a new dir matchers +func NewFileFilters(fls ...Matcher) *MultiFilter { + return &MultiFilter{ + Before: MatchFile, + Filters: fls, + } +} + +// ------------------ Body Matcher ------------------ + +// BodyFilter for filter file contents. +type BodyFilter interface { + Apply(filePath string, buf *bytes.Buffer) bool +} + +// BodyMatcherFunc for filter file contents. +type BodyMatcherFunc func(filePath string, buf *bytes.Buffer) bool + +// Apply for filter file contents. +func (fn BodyMatcherFunc) Apply(filePath string, buf *bytes.Buffer) bool { + return fn(filePath, buf) +} + +// BodyFilters multi body matchers as Matcher +type BodyFilters struct { + Filters []BodyFilter +} + +// NewBodyFilters create a new body matchers +// +// Usage: +// +// bf := finder.NewBodyFilters( +// finder.BodyMatcherFunc(func(filePath string, buf *bytes.Buffer) bool { +// // filter file contents +// return true +// }), +// ) +// +// es := finder.NewFinder('path/to/dir').Add(bf).Elems() +// for el := range es { +// fmt.Println(el.Path()) +// } +func NewBodyFilters(fls ...BodyFilter) *BodyFilters { + return &BodyFilters{ + Filters: fls, + } +} + +// AddFilter add matchers +func (mf *BodyFilters) AddFilter(fls ...BodyFilter) { + mf.Filters = append(mf.Filters, fls...) +} + +// Apply check file path. return False will filter this file. +func (mf *BodyFilters) Apply(el Elem) bool { + if el.IsDir() { + return false + } + + // read file contents + buf := bytes.NewBuffer(nil) + file, err := nfs.OpenReadFile(el.Path()) + if err != nil { + return false + } + + _, err = buf.ReadFrom(file) + if err != nil { + file.Close() + return false + } + file.Close() + + // apply matchers + for _, fl := range mf.Filters { + if !fl.Apply(el.Path(), buf) { + return false + } + } + return true +} diff --git a/nfs/finder/matchers.go b/nfs/finder/matchers.go new file mode 100644 index 0000000..b8429c4 --- /dev/null +++ b/nfs/finder/matchers.go @@ -0,0 +1,289 @@ +package finder + +import ( + "git.noahlan.cn/noahlan/ntool/nfs" + "git.noahlan.cn/noahlan/ntool/nmath" + "git.noahlan.cn/noahlan/ntool/nstr" + "git.noahlan.cn/noahlan/ntool/ntime" + "path" + "regexp" + "strings" + "time" +) + +// ------------------ built in filters ------------------ + +// MatchFile only allow file path. +var MatchFile = MatcherFunc(func(el Elem) bool { + return !el.IsDir() +}) + +// MatchDir only allow dir path. +var MatchDir = MatcherFunc(func(el Elem) bool { + return el.IsDir() +}) + +// StartWithDot match dot file/dir. eg: ".gitignore" +func StartWithDot() MatcherFunc { + return func(el Elem) bool { + name := el.Name() + return len(name) > 0 && name[0] == '.' + } +} + +// MatchDotFile match dot filename. eg: ".idea" +func MatchDotFile() MatcherFunc { + return func(el Elem) bool { + return !el.IsDir() && el.Name()[0] == '.' + } +} + +// MatchDotDir match dot dirname. eg: ".idea" +func MatchDotDir() MatcherFunc { + return func(el Elem) bool { + return el.IsDir() && el.Name()[0] == '.' + } +} + +// MatchExt match filepath by given file ext. +// +// Usage: +// +// f := NewFinder('path/to/dir') +// f.Add(MatchExt(".go")) +// f.Not(MatchExt(".md")) +func MatchExt(exts ...string) MatcherFunc { return MatchExts(exts) } + +// MatchExts filter filepath by given file ext. +func MatchExts(exts []string) MatcherFunc { + return func(el Elem) bool { + elExt := path.Ext(el.Name()) + for _, ext := range exts { + if ext == elExt { + return true + } + } + return false + } +} + +// MatchName match filepath by given names. +// +// Usage: +// +// f := NewFinder('path/to/dir') +// f.Not(MatchName("README.md", "*_test.go")) +func MatchName(names ...string) MatcherFunc { return MatchNames(names) } + +// MatchNames match filepath by given names. +func MatchNames(names []string) MatcherFunc { + return func(el Elem) bool { + elName := el.Name() + for _, name := range names { + if name == elName || nfs.PathMatch(name, elName) { + return true + } + } + return false + } +} + +// MatchPrefix match filepath by check given prefixes. +// +// Usage: +// +// f := NewFinder('path/to/dir') +// f.Add(finder.MatchPrefix("app_", "README")) +func MatchPrefix(prefixes ...string) MatcherFunc { return MatchPrefixes(prefixes) } + +// MatchPrefixes match filepath by check given prefixes. +func MatchPrefixes(prefixes []string) MatcherFunc { + return func(el Elem) bool { + for _, pfx := range prefixes { + if strings.HasPrefix(el.Name(), pfx) { + return true + } + } + return false + } +} + +// MatchSuffix match filepath by check path has suffixes. +// +// Usage: +// +// f := NewFinder('path/to/dir') +// f.Add(finder.MatchSuffix("util.go", "en.md")) +// f.Not(finder.MatchSuffix("_test.go", ".log")) +func MatchSuffix(suffixes ...string) MatcherFunc { return MatchSuffixes(suffixes) } + +// MatchSuffixes match filepath by check path has suffixes. +func MatchSuffixes(suffixes []string) MatcherFunc { + return func(el Elem) bool { + for _, sfx := range suffixes { + if strings.HasSuffix(el.Path(), sfx) { + return true + } + } + return false + } +} + +// MatchPath match file/dir by given sub paths. +// +// Usage: +// +// f := NewFinder('path/to/dir') +// f.Add(MatchPath("need/path")) +func MatchPath(subPaths []string) MatcherFunc { return MatchPaths(subPaths) } + +// MatchPaths match file/dir by given sub paths. +func MatchPaths(subPaths []string) MatcherFunc { + return func(el Elem) bool { + for _, subPath := range subPaths { + if strings.Contains(el.Path(), subPath) { + return true + } + } + return false + } +} + +// GlobMatch file/dir name by given patterns. +// +// Usage: +// +// f := NewFinder('path/to/dir') +// f.AddFilter(GlobMatch("*_test.go")) +func GlobMatch(patterns ...string) MatcherFunc { return GlobMatches(patterns) } + +// GlobMatches file/dir name by given patterns. +func GlobMatches(patterns []string) MatcherFunc { + return func(el Elem) bool { + for _, pattern := range patterns { + if ok, _ := path.Match(pattern, el.Name()); ok { + return true + } + } + return false + } +} + +// RegexMatch match name by given regex pattern +// +// Usage: +// +// f := NewFinder('path/to/dir') +// f.AddFilter(RegexMatch(`[A-Z]\w+`)) +func RegexMatch(pattern string) MatcherFunc { + reg := regexp.MustCompile(pattern) + + return func(el Elem) bool { + return reg.MatchString(el.Name()) + } +} + +// NameLike exclude filepath by given name match. +func NameLike(patterns ...string) MatcherFunc { return NameLikes(patterns) } + +// NameLikes filter filepath by given name match. +func NameLikes(patterns []string) MatcherFunc { + return func(el Elem) bool { + for _, pattern := range patterns { + if nstr.LikeMatch(pattern, el.Name()) { + return true + } + } + return false + } +} + +// +// ----------------- built in file info filters ----------------- +// + +// MatchMtime match file by modify time. +// +// Note: if time is zero, it will be ignored. +// +// Usage: +// +// f := NewFinder('path/to/dir') +// // -600 seconds to now(last 10 minutes) +// f.AddFile(MatchMtime(timex.NowAddSec(-600), timex.ZeroTime)) +// // before 600 seconds(before 10 minutes) +// f.AddFile(MatchMtime(timex.ZeroTime, timex.NowAddSec(-600))) +func MatchMtime(start, end time.Time) MatcherFunc { + return MatchModTime(start, end) +} + +// MatchModTime filter file by modify time. +func MatchModTime(start, end time.Time) MatcherFunc { + return func(el Elem) bool { + if el.IsDir() { + return false + } + + fi, err := el.Info() + if err != nil { + return false + } + return ntime.InRange(fi.ModTime(), start, end) + } +} + +var timeNumReg = regexp.MustCompile(`(-?\d+)`) + +// HumanModTime filter file by modify time string. +// +// Usage: +// +// f := EmptyFinder() +// f.AddFilter(HumanModTime(">10m")) // before 10 minutes +// f.AddFilter(HumanModTime("<10m")) // latest 10 minutes, to Now +func HumanModTime(expr string) MatcherFunc { + opt := &ntime.ParseRangeOpt{AutoSort: true} + // convert > to <, < to > + expr = nstr.Replaces(expr, map[string]string{">": "<", "<": ">"}) + expr = timeNumReg.ReplaceAllStringFunc(expr, func(s string) string { + if s[0] == '-' { + return s + } + return "-" + s + }) + + start, end, err := ntime.ParseRange(expr, opt) + if err != nil { + panic(err) + } + + return MatchModTime(start, end) +} + +// FileSize match file by file size. unit: byte +func FileSize(min, max uint64) MatcherFunc { return SizeRange(min, max) } + +// SizeRange match file by file size. unit: byte +func SizeRange(min, max uint64) MatcherFunc { + return func(el Elem) bool { + if el.IsDir() { + return false + } + + fi, err := el.Info() + if err != nil { + return false + } + return nmath.InUintRange(uint64(fi.Size()), min, max) + } +} + +// HumanSize match file by file size string. eg: ">1k", "<2m", "1g~3g" +func HumanSize(expr string) MatcherFunc { + min, max, err := nstr.ParseSizeRange(expr, nil) + if err != nil { + panic(err) + } + + return SizeRange(min, max) +} diff --git a/nfs/finder/matchers_test.go b/nfs/finder/matchers_test.go new file mode 100644 index 0000000..cbbbf1c --- /dev/null +++ b/nfs/finder/matchers_test.go @@ -0,0 +1,84 @@ +package finder_test + +import ( + "git.noahlan.cn/noahlan/ntool/nfs/finder" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "git.noahlan.cn/noahlan/ntool/ntest/mock" + "testing" +) + +func newMockElem(fp string, isDir ...bool) finder.Elem { + return finder.NewElem(fp, mock.NewDirEnt(fp, isDir...)) +} + +func TestFilters_simple(t *testing.T) { + el := newMockElem("path/some.txt") + fn := finder.MatcherFunc(func(el finder.Elem) bool { + return false + }) + + assert.False(t, fn(el)) + + // match name + fn = finder.MatchName("some.txt") + assert.True(t, fn(el)) + fn = finder.MatchName("not-exist.txt") + assert.False(t, fn(el)) + + // MatchExt + fn = finder.MatchExt(".txt") + assert.True(t, fn(el)) + fn = finder.MatchExt(".js") + assert.False(t, fn(el)) + + // MatchSuffix + fn = finder.MatchSuffix("me.txt") + assert.True(t, fn(el)) + fn = finder.MatchSuffix("not-exist.txt") + assert.False(t, fn(el)) +} + +func TestRegexMatch(t *testing.T) { + tests := []struct { + filePath string + pattern string + match bool + }{ + {"path/to/util.go", `\.go$`, true}, + {"path/to/util.go", `\.md$`, false}, + {"path/to/util.md", `\.md$`, true}, + {"path/to/util.md", `\.go$`, false}, + } + + for _, tt := range tests { + el := newMockElem(tt.filePath) + fn := finder.RegexMatch(tt.pattern) + assert.Eq(t, tt.match, fn(el)) + } +} + +func TestMatchDotDir(t *testing.T) { + f := finder.EmptyFinder(). + WithFlags(finder.FlagBoth). + ScanDir("./testdata") + + dirName := ".dotdir" + assert.Contains(t, f.FindNames(), dirName) + + t.Run("NoDotDir", func(t *testing.T) { + f = finder.EmptyFinder(). + ScanDir("./testdata"). + NoDotDir() + + assert.NotContains(t, f.FindNames(), dirName) + }) + + t.Run("Exclude false", func(t *testing.T) { + f = finder.NewEmpty(). + WithStrFlag("dir"). + ScanDir("./testdata"). + ExcludeDotDir(false) + + assert.Contains(t, f.FindNames(), dirName) + }) +} diff --git a/nfs/finder/testdata/.dotdir/some.txt b/nfs/finder/testdata/.dotdir/some.txt new file mode 100644 index 0000000..e69de29 diff --git a/nfs/finder/testdata/.env b/nfs/finder/testdata/.env new file mode 100644 index 0000000..e69de29 diff --git a/nfs/finder/testdata/test.txt b/nfs/finder/testdata/test.txt new file mode 100644 index 0000000..cf4002c --- /dev/null +++ b/nfs/finder/testdata/test.txt @@ -0,0 +1 @@ +hello, in test.txt \ No newline at end of file diff --git a/nfs/fn.go b/nfs/fn.go new file mode 100644 index 0000000..9b334fa --- /dev/null +++ b/nfs/fn.go @@ -0,0 +1,7 @@ +package nfs + +import "os" + +// CloseOnExec makes sure closing the file on process forking. +func CloseOnExec(file *os.File) { +} diff --git a/nfs/info.go b/nfs/info.go new file mode 100644 index 0000000..00995a2 --- /dev/null +++ b/nfs/info.go @@ -0,0 +1,72 @@ +package nfs + +import ( + "git.noahlan.cn/noahlan/ntool/internal/common" + "os" + "path" + "path/filepath" +) + +// Dir get dir path from filepath, without last name. +func Dir(fpath string) string { return filepath.Dir(fpath) } + +// PathName get file/dir name from full path +func PathName(fpath string) string { return path.Base(fpath) } + +// Name get file/dir name from full path. +// +// eg: path/to/main.go => main.go +func Name(fpath string) string { + if fpath == "" { + return "" + } + return filepath.Base(fpath) +} + +// FileExt get filename ext. alias of path.Ext() +// +// eg: path/to/main.go => ".go" +func FileExt(fpath string) string { return path.Ext(fpath) } + +// Ext get filename ext. alias of path.Ext() +// +// eg: path/to/main.go => ".go" +func Ext(fpath string) string { + return path.Ext(fpath) +} + +// ExtName get filename ext. alias of path.Ext() +// +// eg: path/to/main.go => "go" +func ExtName(fpath string) string { + if ext := path.Ext(fpath); len(ext) > 0 { + return ext[1:] + } + return "" +} + +// Suffix get filename ext. alias of path.Ext() +// +// eg: path/to/main.go => ".go" +func Suffix(fpath string) string { return path.Ext(fpath) } + +// Expand will parse first `~` as user home dir path. +func Expand(pathStr string) string { + return common.ExpandHome(pathStr) +} + +// ExpandPath will parse `~` as user home dir path. +func ExpandPath(pathStr string) string { + return common.ExpandHome(pathStr) +} + +// ResolvePath will parse `~` and env var in path +func ResolvePath(pathStr string) string { + pathStr = common.ExpandHome(pathStr) + return os.ExpandEnv(pathStr) +} + +// SplitPath splits path immediately following the final Separator, separating it into a directory and file name component +func SplitPath(pathStr string) (dir, name string) { + return filepath.Split(pathStr) +} diff --git a/nfs/info_nonwin.go b/nfs/info_nonwin.go new file mode 100644 index 0000000..e7fe656 --- /dev/null +++ b/nfs/info_nonwin.go @@ -0,0 +1,18 @@ +//go:build !windows + +package nfs + +import ( + "git.noahlan.cn/noahlan/ntool/internal/common" + "path" +) + +// Realpath returns the shortest path name equivalent to path by purely lexical processing. +func Realpath(pathStr string) string { + pathStr = common.ExpandHome(pathStr) + + if !IsAbsPath(pathStr) { + pathStr = JoinSubPaths(common.Workdir(), pathStr) + } + return path.Clean(pathStr) +} diff --git a/nfs/info_test.go b/nfs/info_test.go new file mode 100644 index 0000000..4ec81ed --- /dev/null +++ b/nfs/info_test.go @@ -0,0 +1,18 @@ +package nfs_test + +import ( + "git.noahlan.cn/noahlan/ntool/nfs" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestExpandPath(t *testing.T) { + path := "~/.kite" + + assert.NotEq(t, path, nfs.Expand(path)) + assert.NotEq(t, path, nfs.ExpandPath(path)) + assert.NotEq(t, path, nfs.ResolvePath(path)) + + assert.Eq(t, "", nfs.Expand("")) + assert.Eq(t, "/path/to", nfs.Expand("/path/to")) +} diff --git a/nfs/info_windows.go b/nfs/info_windows.go new file mode 100644 index 0000000..508412b --- /dev/null +++ b/nfs/info_windows.go @@ -0,0 +1,18 @@ +//go:build windows + +package nfs + +import ( + "git.noahlan.cn/noahlan/ntool/internal/common" + "path/filepath" +) + +// Realpath returns the shortest path name equivalent to path by purely lexical processing. +func Realpath(pathStr string) string { + pathStr = common.ExpandHome(pathStr) + + if !IsAbsPath(pathStr) { + pathStr = JoinSubPaths(common.Workdir(), pathStr) + } + return filepath.Clean(pathStr) +} diff --git a/nfs/oper.go b/nfs/oper.go new file mode 100644 index 0000000..f8c353c --- /dev/null +++ b/nfs/oper.go @@ -0,0 +1,255 @@ +package nfs + +import ( + "archive/zip" + "fmt" + "git.noahlan.cn/noahlan/ntool/ngo" + "io" + "io/fs" + "os" + "path" + "path/filepath" + "strings" +) + +// Mkdir alias of os.MkdirAll() +func Mkdir(dirPath string, perm os.FileMode) error { + return os.MkdirAll(dirPath, perm) +} + +// MkDirs batch make multi dirs at once +func MkDirs(perm os.FileMode, dirPaths ...string) error { + for _, dirPath := range dirPaths { + if err := os.MkdirAll(dirPath, perm); err != nil { + return err + } + } + return nil +} + +// MkSubDirs batch make multi sub-dirs at once +func MkSubDirs(perm os.FileMode, parentDir string, subDirs ...string) error { + for _, dirName := range subDirs { + dirPath := parentDir + "/" + dirName + if err := os.MkdirAll(dirPath, perm); err != nil { + return err + } + } + return nil +} + +// MkParentDir quick create parent dir +func MkParentDir(fpath string) error { + dirPath := filepath.Dir(fpath) + if !IsDir(dirPath) { + return os.MkdirAll(dirPath, 0775) + } + return nil +} + +// ************************************************************ +// open/create files +// ************************************************************ + +// some commonly flag const for open file +const ( + FsCWAFlags = os.O_CREATE | os.O_WRONLY | os.O_APPEND // create, append write-only + FsCWTFlags = os.O_CREATE | os.O_WRONLY | os.O_TRUNC // create, override write-only + FsCWFlags = os.O_CREATE | os.O_WRONLY // create, write-only + FsRFlags = os.O_RDONLY // read-only +) + +// OpenFile like os.OpenFile, but will auto create dir. +// +// Usage: +// +// file, err := OpenFile("path/to/file.txt", FsCWFlags, 0666) +func OpenFile(filepath string, flag int, perm os.FileMode) (*os.File, error) { + fileDir := path.Dir(filepath) + if err := os.MkdirAll(fileDir, DefaultDirPerm); err != nil { + return nil, err + } + + file, err := os.OpenFile(filepath, flag, perm) + if err != nil { + return nil, err + } + return file, nil +} + +// MustOpenFile like os.OpenFile, but will auto create dir. +// +// Usage: +// +// file := MustOpenFile("path/to/file.txt", FsCWFlags, 0666) +func MustOpenFile(filepath string, flag int, perm os.FileMode) *os.File { + file, err := OpenFile(filepath, flag, perm) + if err != nil { + panic(err) + } + return file +} + +// QuickOpenFile like os.OpenFile, open for append write. if not exists, will create it. +// +// Alias of OpenAppendFile() +func QuickOpenFile(filepath string, fileFlag ...int) (*os.File, error) { + flag := ngo.FirstOr(fileFlag, FsCWAFlags) + return OpenFile(filepath, flag, DefaultFilePerm) +} + +// OpenAppendFile like os.OpenFile, open for append write. if not exists, will create it. +func OpenAppendFile(filepath string, filePerm ...os.FileMode) (*os.File, error) { + perm := ngo.FirstOr(filePerm, DefaultFilePerm) + return OpenFile(filepath, FsCWAFlags, perm) +} + +// OpenTruncFile like os.OpenFile, open for override write. if not exists, will create it. +func OpenTruncFile(filepath string, filePerm ...os.FileMode) (*os.File, error) { + perm := ngo.FirstOr(filePerm, DefaultFilePerm) + return OpenFile(filepath, FsCWTFlags, perm) +} + +// OpenReadFile like os.OpenFile, open file for read contents +func OpenReadFile(filepath string) (*os.File, error) { + return os.OpenFile(filepath, FsRFlags, OnlyReadFilePerm) +} + +// CreateFile create file if not exists +// +// Usage: +// +// CreateFile("path/to/file.txt", 0664, 0666) +func CreateFile(fpath string, filePerm, dirPerm os.FileMode, fileFlag ...int) (*os.File, error) { + dirPath := path.Dir(fpath) + if !IsDir(dirPath) { + err := os.MkdirAll(dirPath, dirPerm) + if err != nil { + return nil, err + } + } + + flag := ngo.FirstOr(fileFlag, FsCWAFlags) + return os.OpenFile(fpath, flag, filePerm) +} + +// MustCreateFile create file, will panic on error +func MustCreateFile(filePath string, filePerm, dirPerm os.FileMode) *os.File { + file, err := CreateFile(filePath, filePerm, dirPerm) + if err != nil { + panic(err) + } + return file +} + +// ************************************************************ +// remove files +// ************************************************************ + +// Remove removes the named file or (empty) directory. +func Remove(fPath string) error { + return os.Remove(fPath) +} + +// MustRemove removes the named file or (empty) directory. +// NOTICE: will panic on error +func MustRemove(fPath string) { + if err := os.Remove(fPath); err != nil { + panic(err) + } +} + +// QuietRemove removes the named file or (empty) directory. +// +// NOTICE: will ignore error +func QuietRemove(fPath string) { _ = os.Remove(fPath) } + +// RmIfExist removes the named file or (empty) directory on exists. +func RmIfExist(fPath string) error { return DeleteIfExist(fPath) } + +// DeleteIfExist removes the named file or (empty) directory on exists. +func DeleteIfExist(fPath string) error { + if PathExists(fPath) { + return os.Remove(fPath) + } + return nil +} + +// RmFileIfExist removes the named file on exists. +func RmFileIfExist(fPath string) error { return DeleteIfFileExist(fPath) } + +// DeleteIfFileExist removes the named file on exists. +func DeleteIfFileExist(fPath string) error { + if IsFile(fPath) { + return os.Remove(fPath) + } + return nil +} + +// RemoveSub removes all sub files and dirs of dirPath, but not remove dirPath. +func RemoveSub(dirPath string, fns ...FilterFunc) error { + return FindInDir(dirPath, func(fPath string, ent fs.DirEntry) error { + if ent.IsDir() { + if err := RemoveSub(fPath, fns...); err != nil { + return err + } + } + return os.Remove(fPath) + }, fns...) +} + +// ************************************************************ +// other operates +// ************************************************************ + +// Unzip a zip archive +// from https://blog.csdn.net/wangshubo1989/article/details/71743374 +func Unzip(archive, targetDir string) (err error) { + reader, err := zip.OpenReader(archive) + if err != nil { + return err + } + + if err = os.MkdirAll(targetDir, DefaultDirPerm); err != nil { + return + } + + for _, file := range reader.File { + if strings.Contains(file.Name, "..") { + return fmt.Errorf("illegal file path in zip: %v", file.Name) + } + + fullPath := filepath.Join(targetDir, file.Name) + + if file.FileInfo().IsDir() { + err = os.MkdirAll(fullPath, file.Mode()) + if err != nil { + return err + } + continue + } + + fileReader, err := file.Open() + if err != nil { + return err + } + + targetFile, err := os.OpenFile(fullPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, file.Mode()) + if err != nil { + _ = fileReader.Close() + return err + } + + _, err = io.Copy(targetFile, fileReader) + + // close all + _ = fileReader.Close() + targetFile.Close() + + if err != nil { + return err + } + } + + return +} diff --git a/nfs/oper_read.go b/nfs/oper_read.go new file mode 100644 index 0000000..635928c --- /dev/null +++ b/nfs/oper_read.go @@ -0,0 +1,137 @@ +package nfs + +import ( + "bufio" + "errors" + "io" + "os" + "text/scanner" +) + +// NewIOReader instance by input file path or io.Reader +func NewIOReader(in any) (r io.Reader, err error) { + switch typIn := in.(type) { + case string: // as file path + return OpenReadFile(typIn) + case io.Reader: + return typIn, nil + } + return nil, errors.New("invalid input type, allow: string, io.Reader") +} + +// DiscardReader anything from the reader +func DiscardReader(src io.Reader) { + _, _ = io.Copy(io.Discard, src) +} + +// ReadFile read file contents, will panic on error +func ReadFile(filePath string) []byte { + return MustReadFile(filePath) +} + +// MustReadFile read file contents, will panic on error +func MustReadFile(filePath string) []byte { + bs, err := os.ReadFile(filePath) + if err != nil { + panic(err) + } + return bs +} + +// ReadReader read contents from io.Reader, will panic on error +func ReadReader(r io.Reader) []byte { return MustReadReader(r) } + +// MustReadReader read contents from io.Reader, will panic on error +func MustReadReader(r io.Reader) []byte { + bs, err := io.ReadAll(r) + if err != nil { + panic(err) + } + return bs +} + +// ReadString read contents from path or io.Reader, will panic on in type error +func ReadString(in any) string { + return string(GetContents(in)) +} + +// ReadStringOrErr read contents from path or io.Reader, will panic on in type error +func ReadStringOrErr(in any) (string, error) { + r, err := NewIOReader(in) + if err != nil { + return "", err + } + + bs, err := io.ReadAll(r) + if err != nil { + return "", err + } + return string(bs), nil +} + +// ReadAll read contents from path or io.Reader, will panic on in type error +func ReadAll(in any) []byte { return GetContents(in) } + +// GetContents read contents from path or io.Reader, will panic on in type error +func GetContents(in any) []byte { + r, err := NewIOReader(in) + if err != nil { + panic(err) + } + return MustReadReader(r) +} + +// ReadOrErr read contents from path or io.Reader, will panic on in type error +func ReadOrErr(in any) ([]byte, error) { + r, err := NewIOReader(in) + if err != nil { + return nil, err + } + return io.ReadAll(r) +} + +// ReadExistFile read file contents if existed, will panic on error +func ReadExistFile(filePath string) []byte { + if IsFile(filePath) { + bs, err := os.ReadFile(filePath) + if err != nil { + panic(err) + } + return bs + } + return nil +} + +// TextScanner from filepath or io.Reader, will panic on in type error +// +// Usage: +// +// s := fsutil.TextScanner("/path/to/file") +// for tok := s.Scan(); tok != scanner.EOF; tok = s.Scan() { +// fmt.Printf("%s: %s\n", s.Position, s.TokenText()) +// } +func TextScanner(in any) *scanner.Scanner { + var s scanner.Scanner + r, err := NewIOReader(in) + if err != nil { + panic(err) + } + + s.Init(r) + s.Filename = "text-scanner" + return &s +} + +// LineScanner create from filepath or io.Reader +// +// s := fsutil.LineScanner("/path/to/file") +// for s.Scan() { +// fmt.Println(s.Text()) +// } +func LineScanner(in any) *bufio.Scanner { + r, err := NewIOReader(in) + if err != nil { + panic(err) + } + return bufio.NewScanner(r) +} diff --git a/nfs/oper_read_test.go b/nfs/oper_read_test.go new file mode 100644 index 0000000..4e20177 --- /dev/null +++ b/nfs/oper_read_test.go @@ -0,0 +1,43 @@ +package nfs_test + +import ( + "git.noahlan.cn/noahlan/ntool/nfs" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "strings" + "testing" +) + +func TestDiscardReader(t *testing.T) { + sr := strings.NewReader("hello") + bs, err := nfs.ReadOrErr(sr) + assert.NoErr(t, err) + assert.Eq(t, []byte("hello"), bs) + + sr = strings.NewReader("hello") + assert.Eq(t, []byte("hello"), nfs.GetContents(sr)) + + sr = strings.NewReader("hello") + nfs.DiscardReader(sr) + + assert.Empty(t, nfs.ReadReader(sr)) + assert.Empty(t, nfs.ReadAll(sr)) + +} + +func TestGetContents(t *testing.T) { + fpath := "./testdata/get-contents.txt" + assert.NoErr(t, nfs.RmFileIfExist(fpath)) + + _, err := nfs.PutContents(fpath, "hello") + assert.NoErr(t, err) + + assert.Nil(t, nfs.ReadExistFile("/path-not-exist")) + assert.Eq(t, []byte("hello"), nfs.ReadExistFile(fpath)) + + assert.Panics(t, func() { + nfs.GetContents(45) + }) + assert.Panics(t, func() { + nfs.ReadFile("/path-not-exist") + }) +} diff --git a/nfs/oper_test.go b/nfs/oper_test.go new file mode 100644 index 0000000..5f56f8e --- /dev/null +++ b/nfs/oper_test.go @@ -0,0 +1,111 @@ +package nfs_test + +import ( + "git.noahlan.cn/noahlan/ntool/nenv" + "git.noahlan.cn/noahlan/ntool/nfs" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "os" + "testing" +) + +func TestMkdir(t *testing.T) { + // TODO windows will error + if nenv.IsWin() { + t.Skip("skip mkdir test on Windows") + return + } + + err := os.Chmod("./testdata", os.ModePerm) + + if assert.NoErr(t, err) { + assert.NoErr(t, nfs.Mkdir("./testdata/sub/sub21", os.ModePerm)) + assert.NoErr(t, nfs.Mkdir("./testdata/sub/sub22", 0666)) + // 066X will error + assert.NoErr(t, nfs.Mkdir("./testdata/sub/sub23/sub31", 0777)) + + assert.NoErr(t, nfs.MkParentDir("./testdata/sub/sub24/sub32")) + assert.True(t, nfs.IsDir("./testdata/sub/sub24")) + + assert.NoErr(t, os.RemoveAll("./testdata/sub")) + } +} + +func TestCreateFile(t *testing.T) { + // TODO windows will error + // if envutil.IsWin() { + // return + // } + + file, err := nfs.CreateFile("./testdata/test.txt", 0664, 0666) + if assert.NoErr(t, err) { + assert.Eq(t, "./testdata/test.txt", file.Name()) + assert.NoErr(t, file.Close()) + assert.NoErr(t, os.Remove(file.Name())) + } + + file, err = nfs.CreateFile("./testdata/sub/test.txt", 0664, 0777) + if assert.NoErr(t, err) { + assert.Eq(t, "./testdata/sub/test.txt", file.Name()) + assert.NoErr(t, file.Close()) + assert.NoErr(t, os.RemoveAll("./testdata/sub")) + } + + file, err = nfs.CreateFile("./testdata/sub/sub2/test.txt", 0664, 0777) + if assert.NoErr(t, err) { + assert.Eq(t, "./testdata/sub/sub2/test.txt", file.Name()) + assert.NoErr(t, file.Close()) + assert.NoErr(t, os.RemoveAll("./testdata/sub")) + } + + fpath := "./testdata/sub/sub3/test-must-create.txt" + assert.NoErr(t, nfs.RmFileIfExist(fpath)) + file = nfs.MustCreateFile(fpath, 0, 0766) + assert.NoErr(t, file.Close()) + + err = nfs.RemoveSub("./testdata/sub") + assert.NoErr(t, err) +} + +func TestQuickOpenFile(t *testing.T) { + fpath := "./testdata/quick-open-file.txt" + assert.NoErr(t, nfs.RmFileIfExist(fpath)) + + file, err := nfs.QuickOpenFile(fpath) + assert.NoErr(t, err) + assert.Eq(t, fpath, file.Name()) + + _, err = file.WriteString("hello") + assert.NoErr(t, err) + + // close + assert.NoErr(t, file.Close()) + + // open for read + file, err = nfs.OpenReadFile(fpath) + assert.NoErr(t, err) + // var bts [5]byte + bts := make([]byte, 5) + _, err = file.Read(bts) + assert.NoErr(t, err) + assert.Eq(t, "hello", string(bts)) + + // close + assert.NoErr(t, file.Close()) + assert.NoErr(t, nfs.Remove(file.Name())) +} + +func TestMustRemove(t *testing.T) { + assert.Panics(t, func() { + nfs.MustRemove("/path-not-exist") + }) +} + +func TestQuietRemove(t *testing.T) { + assert.NotPanics(t, func() { + nfs.QuietRemove("/path-not-exist") + }) +} + +func TestUnzip(t *testing.T) { + assert.Err(t, nfs.Unzip("/path-not-exists", "")) +} diff --git a/nfs/oper_write.go b/nfs/oper_write.go new file mode 100644 index 0000000..c19c661 --- /dev/null +++ b/nfs/oper_write.go @@ -0,0 +1,100 @@ +package nfs + +import ( + "git.noahlan.cn/noahlan/ntool/ngo" + "io" + "os" +) + +// ************************************************************ +// write, copy files +// ************************************************************ + +// PutContents create file and write contents to file at once. +// +// data type allow: string, []byte, io.Reader +// +// Tip: file flag default is FsCWTFlags (override write) +// +// Usage: +// +// nfs.PutContents(filePath, contents, nfs.FsCWAFlags) // append write +func PutContents(filePath string, data any, fileFlag ...int) (int, error) { + f, err := QuickOpenFile(filePath, ngo.FirstOr(fileFlag, FsCWTFlags)) + if err != nil { + return 0, err + } + + return WriteOSFile(f, data) +} + +// WriteFile create file and write contents to file, can set perm for file. +// +// data type allow: string, []byte, io.Reader +// +// Tip: file flag default is FsCWTFlags (override write) +// +// Usage: +// +// nfs.WriteFile(filePath, contents, nfs.DefaultFilePerm, nfs.FsCWAFlags) +func WriteFile(filePath string, data any, perm os.FileMode, fileFlag ...int) error { + flag := ngo.FirstOr(fileFlag, FsCWTFlags) + f, err := OpenFile(filePath, flag, perm) + if err != nil { + return err + } + + _, err = WriteOSFile(f, data) + return err +} + +// WriteOSFile write data to give os.File, then close file. +// +// data type allow: string, []byte, io.Reader +func WriteOSFile(f *os.File, data any) (n int, err error) { + switch typData := data.(type) { + case []byte: + n, err = f.Write(typData) + case string: + n, err = f.WriteString(typData) + case io.Reader: // eg: buffer + var n64 int64 + n64, err = io.Copy(f, typData) + n = int(n64) + default: + _ = f.Close() + panic("WriteFile: data type only allow: []byte, string, io.Reader") + } + + if err1 := f.Close(); err1 != nil && err == nil { + err = err1 + } + return n, err +} + +// CopyFile copy a file to another file path. +func CopyFile(srcPath, dstPath string) error { + srcFile, err := os.OpenFile(srcPath, FsRFlags, 0) + if err != nil { + return err + } + defer srcFile.Close() + + // create and open file + dstFile, err := QuickOpenFile(dstPath, FsCWTFlags) + if err != nil { + return err + } + defer dstFile.Close() + + _, err = io.Copy(dstFile, srcFile) + return err +} + +// MustCopyFile copy file to another path. +func MustCopyFile(srcPath, dstPath string) { + err := CopyFile(srcPath, dstPath) + if err != nil { + panic(err) + } +} diff --git a/nfs/oper_write_test.go b/nfs/oper_write_test.go new file mode 100644 index 0000000..7741164 --- /dev/null +++ b/nfs/oper_write_test.go @@ -0,0 +1,26 @@ +package nfs_test + +import ( + "git.noahlan.cn/noahlan/ntool/nfs" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestMustCopyFile(t *testing.T) { + srcPath := "./testdata/cp-file-src.txt" + dstPath := "./testdata/cp-file-dst.txt" + + assert.NoErr(t, nfs.RmIfExist(srcPath)) + assert.NoErr(t, nfs.RmFileIfExist(dstPath)) + + _, err := nfs.PutContents(srcPath, "hello") + assert.NoErr(t, err) + + nfs.MustCopyFile(srcPath, dstPath) + assert.Eq(t, []byte("hello"), nfs.GetContents(dstPath)) + assert.Eq(t, "hello", nfs.ReadString(dstPath)) + + str, err := nfs.ReadStringOrErr(dstPath) + assert.NoErr(t, err) + assert.Eq(t, "hello", str) +} diff --git a/nfs/testdata/.dotdir/some.txt b/nfs/testdata/.dotdir/some.txt new file mode 100644 index 0000000..e69de29 diff --git a/nfs/testdata/.env b/nfs/testdata/.env new file mode 100644 index 0000000..e69de29 diff --git a/nfs/testdata/cp-file-dst.txt b/nfs/testdata/cp-file-dst.txt new file mode 100644 index 0000000..b6fc4c6 --- /dev/null +++ b/nfs/testdata/cp-file-dst.txt @@ -0,0 +1 @@ +hello \ No newline at end of file diff --git a/nfs/testdata/cp-file-src.txt b/nfs/testdata/cp-file-src.txt new file mode 100644 index 0000000..b6fc4c6 --- /dev/null +++ b/nfs/testdata/cp-file-src.txt @@ -0,0 +1 @@ +hello \ No newline at end of file diff --git a/nfs/testdata/get-contents.txt b/nfs/testdata/get-contents.txt new file mode 100644 index 0000000..b6fc4c6 --- /dev/null +++ b/nfs/testdata/get-contents.txt @@ -0,0 +1 @@ +hello \ No newline at end of file diff --git a/nfs/testdata/mimetext.txt b/nfs/testdata/mimetext.txt new file mode 100644 index 0000000..69d29d3 --- /dev/null +++ b/nfs/testdata/mimetext.txt @@ -0,0 +1 @@ +package testdata diff --git a/nfs/util.go b/nfs/util.go new file mode 100644 index 0000000..97bf961 --- /dev/null +++ b/nfs/util.go @@ -0,0 +1,155 @@ +package nfs + +import ( + "git.noahlan.cn/noahlan/ntool/internal/common" + "git.noahlan.cn/noahlan/ntool/ncrypt" + "io" + "net/http" + "os" + "path/filepath" + "strings" +) + +const ( + // MimeSniffLen sniff Length, use for detect file mime type + MimeSniffLen = 512 +) + +// OSTempFile create a temp file on os.TempDir() +// The file is kept as open, the caller should close the file handle, and remove the file by name. +// +// Usage: +// +// nfs.OSTempFile("example.*.txt") +func OSTempFile(pattern string) (*os.File, error) { + return os.CreateTemp(os.TempDir(), pattern) +} + +// OSTempFileWithContent create a temp file on os.TempDir() with the given content +func OSTempFileWithContent(content string) (*os.File, error) { + tmp, err := OSTempFile(ncrypt.Md5String(content)) + if err != nil { + return nil, err + } + if err = os.WriteFile(tmp.Name(), []byte(content), os.ModeTemporary); err != nil { + return nil, err + } + return tmp, nil +} + +// OSTempFilenameWithContent create a temp file on os.TempDir() with the given content and returns the filename (full path). +// The caller should remove the file by name after use. +func OSTempFilenameWithContent(content string) (string, error) { + tmp, err := OSTempFileWithContent(content) + if err != nil { + return "", err + } + filename := tmp.Name() + if err = tmp.Close(); err != nil { + return "", err + } + return filename, nil +} + +// TempFile is like os.CreateTemp, but can custom temp dir. +// +// Usage: +// +// nfs.TempFile("", "example.*.txt") +func TempFile(dir, pattern string) (*os.File, error) { + return os.CreateTemp(dir, pattern) +} + +// OSTempDir creates a new temp dir on os.TempDir and return the temp dir path +// +// Usage: +// +// nfs.OSTempDir("example.*") +func OSTempDir(pattern string) (string, error) { + return os.MkdirTemp(os.TempDir(), pattern) +} + +// TempDir creates a new temp dir and return the temp dir path +// +// Usage: +// +// nfs.TempDir("", "example.*") +// nfs.TempDir("testdata", "example.*") +func TempDir(dir, pattern string) (string, error) { + return os.MkdirTemp(dir, pattern) +} + +// MimeType get File Mime Type name. eg "image/png" +func MimeType(path string) (mime string) { + file, err := os.Open(path) + if err != nil { + return + } + + return ReaderMimeType(file) +} + +// ReaderMimeType get the io.Reader mimeType +// +// Usage: +// +// file, err := os.Open(filepath) +// if err != nil { +// return +// } +// mime := ReaderMimeType(file) +func ReaderMimeType(r io.Reader) (mime string) { + var buf [MimeSniffLen]byte + n, _ := io.ReadFull(r, buf[:]) + if n == 0 { + return "" + } + + return http.DetectContentType(buf[:n]) +} + +// JoinPaths elements, alias of filepath.Join() +func JoinPaths(elem ...string) string { + return filepath.Join(elem...) +} + +// JoinSubPaths elements, like the filepath.Join() +func JoinSubPaths(basePath string, elem ...string) string { + paths := make([]string, len(elem)+1) + paths[0] = basePath + copy(paths[1:], elem) + return filepath.Join(paths...) +} + +// SlashPath alias of filepath.ToSlash +func SlashPath(path string) string { + return filepath.ToSlash(path) +} + +// UnixPath like of filepath.ToSlash, but always replace +func UnixPath(path string) string { + if !strings.ContainsRune(path, '\\') { + return path + } + return strings.ReplaceAll(path, "\\", "/") +} + +// ToAbsPath convert process. will expand home dir +// +// TIP: will don't check path +func ToAbsPath(p string) string { + if len(p) == 0 || IsAbsPath(p) { + return p + } + + // expand home dir + if p[0] == '~' { + return common.ExpandHome(p) + } + + wd, err := os.Getwd() + if err != nil { + return p + } + return filepath.Join(wd, p) +} diff --git a/nfs/util_nonwin.go b/nfs/util_nonwin.go new file mode 100644 index 0000000..85182b0 --- /dev/null +++ b/nfs/util_nonwin.go @@ -0,0 +1,16 @@ +//go:build !windows +// +build !windows + +package nfs + +import ( + "os" + "syscall" +) + +// CloseOnExec makes sure closing the file on process forking. +func CloseOnExec(file *os.File) { + if file != nil { + syscall.CloseOnExec(int(file.Fd())) + } +} diff --git a/nfs/util_nonwin_test.go b/nfs/util_nonwin_test.go new file mode 100644 index 0000000..c4000bf --- /dev/null +++ b/nfs/util_nonwin_test.go @@ -0,0 +1,19 @@ +//go:build !windows + +package nfs_test + +import ( + "git.noahlan.cn/noahlan/ntool/nfs" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestSlashPath_nw(t *testing.T) { + assert.Eq(t, "path/to/dir", nfs.JoinPaths("path", "to", "dir")) + assert.Eq(t, "path/to/dir", nfs.JoinSubPaths("path", "to", "dir")) +} + +func TestRealpath_nw(t *testing.T) { + inPath := "/path/to/some/../dir" + assert.Eq(t, "/path/to/dir", nfs.Realpath(inPath)) +} diff --git a/nfs/util_test.go b/nfs/util_test.go new file mode 100644 index 0000000..fda82d7 --- /dev/null +++ b/nfs/util_test.go @@ -0,0 +1,59 @@ +package nfs_test + +import ( + "bytes" + "git.noahlan.cn/noahlan/ntool/nfs" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestMimeType(t *testing.T) { + assert.Eq(t, "", nfs.MimeType("")) + assert.Eq(t, "", nfs.MimeType("not-exist")) + assert.Eq(t, "text/plain; charset=utf-8", nfs.MimeType("testdata/mimetext.txt")) + + buf := new(bytes.Buffer) + buf.Write([]byte("\xFF\xD8\xFF")) + assert.Eq(t, "image/jpeg", nfs.ReaderMimeType(buf)) + buf.Reset() + + buf.Write([]byte("text")) + assert.Eq(t, "text/plain; charset=utf-8", nfs.ReaderMimeType(buf)) + buf.Reset() + + buf.Write([]byte("")) + assert.Eq(t, "", nfs.ReaderMimeType(buf)) + buf.Reset() + + assert.False(t, nfs.IsImageFile("testdata/test.txt")) + assert.False(t, nfs.IsImageFile("testdata/not-exists")) +} + +func TestTempDir(t *testing.T) { + dir, err := nfs.TempDir("testdata", "temp.*") + assert.NoErr(t, err) + assert.True(t, nfs.IsDir(dir)) + assert.NoErr(t, nfs.Remove(dir)) +} + +func TestSplitPath(t *testing.T) { + dir, file := nfs.SplitPath("/path/to/dir/some.txt") + assert.Eq(t, "/path/to/dir/", dir) + assert.Eq(t, "some.txt", file) +} + +func TestToAbsPath(t *testing.T) { + assert.Eq(t, "", nfs.ToAbsPath("")) + assert.Eq(t, "/path/to/dir/", nfs.ToAbsPath("/path/to/dir/")) + assert.Neq(t, "~/path/to/dir", nfs.ToAbsPath("~/path/to/dir")) + assert.Neq(t, ".", nfs.ToAbsPath(".")) + assert.Neq(t, "..", nfs.ToAbsPath("..")) + assert.Neq(t, "./", nfs.ToAbsPath("./")) + assert.Neq(t, "../", nfs.ToAbsPath("../")) +} + +func TestSlashPath(t *testing.T) { + assert.Eq(t, "/path/to/dir", nfs.SlashPath("/path/to/dir")) + assert.Eq(t, "/path/to/dir", nfs.UnixPath("/path/to/dir")) + assert.Eq(t, "/path/to/dir", nfs.UnixPath("\\path\\to\\dir")) +} diff --git a/nfs/util_windows_test.go b/nfs/util_windows_test.go new file mode 100644 index 0000000..b7251a1 --- /dev/null +++ b/nfs/util_windows_test.go @@ -0,0 +1,19 @@ +//go:build windows + +package nfs_test + +import ( + "git.noahlan.cn/noahlan/ntool/nfs" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestSlashPath_win(t *testing.T) { + assert.Eq(t, "path\\to\\dir", nfs.JoinPaths("path", "to", "dir")) + assert.Eq(t, "path\\to\\dir", nfs.JoinSubPaths("path", "to", "dir")) +} + +func TestRealpath_win(t *testing.T) { + inPath := "/path/to/some/../dir" + assert.Eq(t, "\\path\\to\\dir", nfs.Realpath(inPath)) +} diff --git a/ngo/base_fn.go b/ngo/base_fn.go new file mode 100644 index 0000000..5ffdcbd --- /dev/null +++ b/ngo/base_fn.go @@ -0,0 +1,72 @@ +package ngo + +// Must if error is not empty, will panic +func Must(err error) { + if err != nil { + panic(err) + } +} + +// MustV if error is not empty, will panic. otherwise return the value. +func MustV[T any](v T, err error) T { + if err != nil { + panic(err) + } + return v +} + +// ErrOnFail return input error on cond is false, otherwise return nil +func ErrOnFail(cond bool, err error) error { + return OrError(cond, err) +} + +// OrError return input error on cond is false, otherwise return nil +func OrError(cond bool, err error) error { + if !cond { + return err + } + return nil +} + +// FirstOr get first elem or elseVal +func FirstOr[T any](sl []T, elseVal T) T { + if len(sl) > 0 { + return sl[0] + } + return elseVal +} + +// OrValue get +func OrValue[T any](cond bool, okVal, elVal T) T { + if cond { + return okVal + } + return elVal +} + +// OrReturn call okFunc() on condition is true, else call elseFn() +func OrReturn[T any](cond bool, okFn, elseFn func() T) T { + if cond { + return okFn() + } + return elseFn() +} + +// ErrFunc type +type ErrFunc func() error + +// CallOn call func on condition is true +func CallOn(cond bool, fn ErrFunc) error { + if cond { + return fn() + } + return nil +} + +// CallOrElse call okFunc() on condition is true, else call elseFn() +func CallOrElse(cond bool, okFn, elseFn ErrFunc) error { + if cond { + return okFn() + } + return elseFn() +} diff --git a/ngo/codec/serializer_json.go b/ngo/codec/serializer_json.go new file mode 100644 index 0000000..69f63e7 --- /dev/null +++ b/ngo/codec/serializer_json.go @@ -0,0 +1,25 @@ +package codec + +import ( + "encoding/json" + "git.noahlan.cn/noahlan/ntool/ndef" +) + +type JsonSerializer struct { +} + +func NewJsonSerializer() ndef.Serializer { + return &JsonSerializer{} +} + +func (s *JsonSerializer) Marshal(v interface{}) ([]byte, error) { + marshal, err := json.Marshal(v) + if err != nil { + return nil, err + } + return marshal, nil +} + +func (s *JsonSerializer) Unmarshal(data []byte, v interface{}) error { + return json.Unmarshal(data, v) +} diff --git a/ngo/ext_fn.go b/ngo/ext_fn.go new file mode 100644 index 0000000..593da45 --- /dev/null +++ b/ngo/ext_fn.go @@ -0,0 +1,23 @@ +package ngo + +import "fmt" + +// DataSize format bytes number friendly. eg: 1024 => 1KB, 1024*1024 => 1MB +// +// Usage: +// +// file, err := os.Open(path) +// fl, err := file.Stat() +// fmtSize := DataSize(fl.Size()) +func DataSize(size uint64) string { + switch { + case size < 1024: + return fmt.Sprintf("%dB", size) + case size < 1024*1024: + return fmt.Sprintf("%.2fK", float64(size)/1024) + case size < 1024*1024*1024: + return fmt.Sprintf("%.2fM", float64(size)/1024/1024) + default: + return fmt.Sprintf("%.2fG", float64(size)/1024/1024/1024) + } +} diff --git a/nlog/color.go b/nlog/color.go new file mode 100644 index 0000000..b5b58d1 --- /dev/null +++ b/nlog/color.go @@ -0,0 +1,22 @@ +package nlog + +import ( + "git.noahlan.cn/noahlan/ntool/nstr" + "github.com/gookit/color" + "sync/atomic" +) + +// WithColor is a helper function to add color to a string, only in plain encoding. +func WithColor(text string, colour color.Color) string { + if atomic.LoadUint32(&encoding) == plainEncodingType { + return colour.Render(text) + } + + return text +} + +// WithColorPadding is a helper function to add color to a string with leading and trailing spaces, +// only in plain encoding. +func WithColorPadding(text string, colour color.Color) string { + return WithColor(nstr.PadAround(text, " ", len(text)+2), colour) +} diff --git a/nlog/color_test.go b/nlog/color_test.go new file mode 100644 index 0000000..93ca983 --- /dev/null +++ b/nlog/color_test.go @@ -0,0 +1,32 @@ +package nlog + +import ( + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "github.com/gookit/color" + "sync/atomic" + "testing" +) + +func TestWithColor(t *testing.T) { + old := atomic.SwapUint32(&encoding, plainEncodingType) + defer atomic.StoreUint32(&encoding, old) + + output := WithColor("hello", color.BgBlue) + assert.Equal(t, "hello", output) + + atomic.StoreUint32(&encoding, jsonEncodingType) + output = WithColor("hello", color.BgBlue) + assert.Equal(t, "hello", output) +} + +func TestWithColorPadding(t *testing.T) { + old := atomic.SwapUint32(&encoding, plainEncodingType) + defer atomic.StoreUint32(&encoding, old) + + output := WithColorPadding("hello", color.BgBlue) + assert.Equal(t, " hello ", output) + + atomic.StoreUint32(&encoding, jsonEncodingType) + output = WithColorPadding("hello", color.BgBlue) + assert.Equal(t, "hello", output) +} diff --git a/nlog/config.go b/nlog/config.go new file mode 100644 index 0000000..3705eb9 --- /dev/null +++ b/nlog/config.go @@ -0,0 +1,45 @@ +package nlog + +// A LogConf is a logging config. +type LogConf struct { + // ServiceName represents the service name. + ServiceName string `json:",optional"` + // Mode represents the logging mode, default is `console`. + // console: log to console. + // file: log to file. + // volume: used in k8s, prepend the hostname to the log file name. + Mode string `json:",default=console,options=[console,file,volume]"` + // Encoding represents the encoding type, default is `plain`. + // json: json encoding. + // plain: plain text encoding, typically used in development. + Encoding string `json:",default=plain,options=[json,plain]"` + // TimeFormat represents the time format, default is `2006-01-02T15:04:05.000Z07:00`. + TimeFormat string `json:",optional"` + // Path represents the log file path, default is `logs`. + Path string `json:",default=logs"` + // Level represents the log level, default is `debug`. + Level string `json:",default=debug,options=[debug,info,error,severe]"` + // MaxContentLength represents the max content bytes, default is no limit. + MaxContentLength uint32 `json:",optional"` + // Compress represents whether to compress the log file, default is `false`. + Compress bool `json:",optional"` + // Stdout represents whether to log statistics, default is `true`. + Stat bool `json:",default=true"` + // KeepDays represents how many days the log files will be kept. Default to keep all files. + // Only take effect when Mode is `file` or `volume`, both work when Rotation is `daily` or `size`. + KeepDays int `json:",optional"` + // StackCooldownMillis represents the cooldown time for stack logging, default is 100ms. + StackCooldownMillis int `json:",default=100"` + // MaxBackups represents how many backup log files will be kept. 0 means all files will be kept forever. + // Only take effect when RotationRuleType is `size`. + // Even thougth `MaxBackups` sets 0, log files will still be removed + // if the `KeepDays` limitation is reached. + MaxBackups int `json:",default=0"` + // MaxSize represents how much space the writing log file takes up. 0 means no limit. The unit is `MB`. + // Only take effect when RotationRuleType is `size` + MaxSize int `json:",default=0"` + // RotationRuleType represents the type of log rotation rule. Default is `daily`. + // daily: daily rotation. + // size: size limited rotation. + Rotation string `json:",default=daily,options=[daily,size]"` +} diff --git a/nlog/fields.go b/nlog/fields.go new file mode 100644 index 0000000..51f4a88 --- /dev/null +++ b/nlog/fields.go @@ -0,0 +1,48 @@ +package nlog + +import ( + "context" + "sync" + "sync/atomic" +) + +var ( + fieldsContextKey contextKey + globalFields atomic.Value + globalFieldsLock sync.Mutex +) + +type contextKey struct{} + +// AddGlobalFields adds global fields. +func AddGlobalFields(fields ...LogField) { + globalFieldsLock.Lock() + defer globalFieldsLock.Unlock() + + old := globalFields.Load() + if old == nil { + globalFields.Store(append([]LogField(nil), fields...)) + } else { + globalFields.Store(append(old.([]LogField), fields...)) + } +} + +// ContextWithFields returns a new context with the given fields. +func ContextWithFields(ctx context.Context, fields ...LogField) context.Context { + if val := ctx.Value(fieldsContextKey); val != nil { + if arr, ok := val.([]LogField); ok { + allFields := make([]LogField, 0, len(arr)+len(fields)) + allFields = append(allFields, arr...) + allFields = append(allFields, fields...) + return context.WithValue(ctx, fieldsContextKey, allFields) + } + } + + return context.WithValue(ctx, fieldsContextKey, fields) +} + +// WithFields returns a new logger with the given fields. +// deprecated: use ContextWithFields instead. +func WithFields(ctx context.Context, fields ...LogField) context.Context { + return ContextWithFields(ctx, fields...) +} diff --git a/nlog/fields_test.go b/nlog/fields_test.go new file mode 100644 index 0000000..f9b72f7 --- /dev/null +++ b/nlog/fields_test.go @@ -0,0 +1,120 @@ +package nlog + +import ( + "bytes" + "context" + "encoding/json" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "strconv" + "sync" + "sync/atomic" + "testing" +) + +func TestAddGlobalFields(t *testing.T) { + var buf bytes.Buffer + writer := NewWriter(&buf) + old := Reset() + SetWriter(writer) + defer SetWriter(old) + + Info("hello") + buf.Reset() + + AddGlobalFields(Field("a", "1"), Field("b", "2")) + AddGlobalFields(Field("c", "3")) + Info("world") + var m map[string]any + assert.NoError(t, json.Unmarshal(buf.Bytes(), &m)) + assert.Equal(t, "1", m["a"]) + assert.Equal(t, "2", m["b"]) + assert.Equal(t, "3", m["c"]) +} + +func TestContextWithFields(t *testing.T) { + ctx := ContextWithFields(context.Background(), Field("a", 1), Field("b", 2)) + vals := ctx.Value(fieldsContextKey) + assert.NotNil(t, vals) + fields, ok := vals.([]LogField) + assert.True(t, ok) + assert.EqualValues(t, []LogField{Field("a", 1), Field("b", 2)}, fields) +} + +func TestWithFields(t *testing.T) { + ctx := WithFields(context.Background(), Field("a", 1), Field("b", 2)) + vals := ctx.Value(fieldsContextKey) + assert.NotNil(t, vals) + fields, ok := vals.([]LogField) + assert.True(t, ok) + assert.EqualValues(t, []LogField{Field("a", 1), Field("b", 2)}, fields) +} + +func TestWithFieldsAppend(t *testing.T) { + var dummyKey struct{} + ctx := context.WithValue(context.Background(), dummyKey, "dummy") + ctx = ContextWithFields(ctx, Field("a", 1), Field("b", 2)) + ctx = ContextWithFields(ctx, Field("c", 3), Field("d", 4)) + vals := ctx.Value(fieldsContextKey) + assert.NotNil(t, vals) + fields, ok := vals.([]LogField) + assert.True(t, ok) + assert.Equal(t, "dummy", ctx.Value(dummyKey)) + assert.EqualValues(t, []LogField{ + Field("a", 1), + Field("b", 2), + Field("c", 3), + Field("d", 4), + }, fields) +} + +func TestWithFieldsAppendCopy(t *testing.T) { + const count = 10 + ctx := context.Background() + for i := 0; i < count; i++ { + ctx = ContextWithFields(ctx, Field(strconv.Itoa(i), 1)) + } + + af := Field("foo", 1) + bf := Field("bar", 2) + ctxa := ContextWithFields(ctx, af) + ctxb := ContextWithFields(ctx, bf) + + assert.EqualValues(t, af, ctxa.Value(fieldsContextKey).([]LogField)[count]) + assert.EqualValues(t, bf, ctxb.Value(fieldsContextKey).([]LogField)[count]) +} + +func BenchmarkAtomicValue(b *testing.B) { + b.ReportAllocs() + + var container atomic.Value + vals := []LogField{ + Field("a", "b"), + Field("c", "d"), + Field("e", "f"), + } + container.Store(&vals) + + for i := 0; i < b.N; i++ { + val := container.Load() + if val != nil { + _ = *val.(*[]LogField) + } + } +} + +func BenchmarkRWMutex(b *testing.B) { + b.ReportAllocs() + + var lock sync.RWMutex + vals := []LogField{ + Field("a", "b"), + Field("c", "d"), + Field("e", "f"), + } + + for i := 0; i < b.N; i++ { + lock.RLock() + _ = vals + lock.RUnlock() + } +} diff --git a/nlog/lesslogger.go b/nlog/lesslogger.go new file mode 100644 index 0000000..a040a7d --- /dev/null +++ b/nlog/lesslogger.go @@ -0,0 +1,27 @@ +package nlog + +// A LessLogger is a logger that control to log once during the given duration. +type LessLogger struct { + *limitedExecutor +} + +// NewLessLogger returns a LessLogger. +func NewLessLogger(milliseconds int) *LessLogger { + return &LessLogger{ + limitedExecutor: newLimitedExecutor(milliseconds), + } +} + +// Error logs v into error log or discard it if more than once in the given duration. +func (logger *LessLogger) Error(v ...any) { + logger.logOrDiscard(func() { + Error(v...) + }) +} + +// Errorf logs v with format into error log or discard it if more than once in the given duration. +func (logger *LessLogger) Errorf(format string, v ...any) { + logger.logOrDiscard(func() { + Errorf(format, v...) + }) +} diff --git a/nlog/lesslogger_test.go b/nlog/lesslogger_test.go new file mode 100644 index 0000000..d3628df --- /dev/null +++ b/nlog/lesslogger_test.go @@ -0,0 +1,34 @@ +package nlog + +import ( + "strings" + "testing" + + "git.noahlan.cn/noahlan/ntool/ntest/assert" +) + +func TestLessLogger_Error(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + l := NewLessLogger(500) + for i := 0; i < 100; i++ { + l.Error("hello") + } + + assert.Equal(t, 1, strings.Count(w.String(), "\n")) +} + +func TestLessLogger_Errorf(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + l := NewLessLogger(500) + for i := 0; i < 100; i++ { + l.Errorf("hello") + } + + assert.Equal(t, 1, strings.Count(w.String(), "\n")) +} diff --git a/nlog/lesswriter.go b/nlog/lesswriter.go new file mode 100644 index 0000000..73a2b09 --- /dev/null +++ b/nlog/lesswriter.go @@ -0,0 +1,22 @@ +package nlog + +import "io" + +type lessWriter struct { + *limitedExecutor + writer io.Writer +} + +func newLessWriter(writer io.Writer, milliseconds int) *lessWriter { + return &lessWriter{ + limitedExecutor: newLimitedExecutor(milliseconds), + writer: writer, + } +} + +func (w *lessWriter) Write(p []byte) (n int, err error) { + w.logOrDiscard(func() { + w.writer.Write(p) + }) + return len(p), nil +} diff --git a/nlog/lesswriter_test.go b/nlog/lesswriter_test.go new file mode 100644 index 0000000..335a830 --- /dev/null +++ b/nlog/lesswriter_test.go @@ -0,0 +1,19 @@ +package nlog + +import ( + "strings" + "testing" + + "git.noahlan.cn/noahlan/ntool/ntest/assert" +) + +func TestLessWriter(t *testing.T) { + var builder strings.Builder + w := newLessWriter(&builder, 500) + for i := 0; i < 100; i++ { + _, err := w.Write([]byte("hello")) + assert.Nil(t, err) + } + + assert.Equal(t, "hello", builder.String()) +} diff --git a/nlog/limitedexecutor.go b/nlog/limitedexecutor.go new file mode 100644 index 0000000..e0a21df --- /dev/null +++ b/nlog/limitedexecutor.go @@ -0,0 +1,40 @@ +package nlog + +import ( + natomic "git.noahlan.cn/noahlan/ntool/nsys/atomic" + "sync/atomic" + "time" +) + +type limitedExecutor struct { + threshold time.Duration + lastTime *natomic.AtomicDuration + discarded uint32 +} + +func newLimitedExecutor(milliseconds int) *limitedExecutor { + return &limitedExecutor{ + threshold: time.Duration(milliseconds) * time.Millisecond, + lastTime: natomic.NewAtomicDuration(), + } +} + +func (le *limitedExecutor) logOrDiscard(execute func()) { + if le == nil || le.threshold <= 0 { + execute() + return + } + + now := time.Since(time.Now().AddDate(-1, -1, -1)) + if now-le.lastTime.Load() <= le.threshold { + atomic.AddUint32(&le.discarded, 1) + } else { + le.lastTime.Set(now) + discarded := atomic.SwapUint32(&le.discarded, 0) + if discarded > 0 { + Errorf("Discarded %d error messages", discarded) + } + + execute() + } +} diff --git a/nlog/limitedexecutor_test.go b/nlog/limitedexecutor_test.go new file mode 100644 index 0000000..b0b59d7 --- /dev/null +++ b/nlog/limitedexecutor_test.go @@ -0,0 +1,61 @@ +package nlog + +import ( + "sync/atomic" + "testing" + "time" + + "git.noahlan.cn/noahlan/ntool/ntest/assert" +) + +func TestLimitedExecutor_logOrDiscard(t *testing.T) { + tests := []struct { + name string + threshold time.Duration + lastTime time.Duration + discarded uint32 + executed bool + }{ + { + name: "nil executor", + executed: true, + }, + { + name: "regular", + threshold: time.Hour, + lastTime: time.Since(time.Now().AddDate(-1, -1, -1)), + discarded: 10, + executed: false, + }, + { + name: "slow", + threshold: time.Duration(1), + lastTime: -1000, + discarded: 10, + executed: true, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + executor := newLimitedExecutor(0) + executor.threshold = test.threshold + executor.discarded = test.discarded + executor.lastTime.Set(test.lastTime) + + var run int32 + executor.logOrDiscard(func() { + atomic.AddInt32(&run, 1) + }) + if test.executed { + assert.Equal(t, int32(1), atomic.LoadInt32(&run)) + } else { + assert.Equal(t, int32(0), atomic.LoadInt32(&run)) + assert.Equal(t, test.discarded+1, atomic.LoadUint32(&executor.discarded)) + } + }) + } +} diff --git a/nlog/logger.go b/nlog/logger.go new file mode 100644 index 0000000..b871661 --- /dev/null +++ b/nlog/logger.go @@ -0,0 +1,50 @@ +package nlog + +import ( + "context" + "time" +) + +// A Logger represents a logger. +type Logger interface { + // Debug logs a message at info level. + Debug(...any) + // Debugf logs a message at info level. + Debugf(string, ...any) + // Debugv logs a message at info level. + Debugv(any) + // Debugw logs a message at info level. + Debugw(string, ...LogField) + // Error logs a message at error level. + Error(...any) + // Errorf logs a message at error level. + Errorf(string, ...any) + // Errorv logs a message at error level. + Errorv(any) + // Errorw logs a message at error level. + Errorw(string, ...LogField) + // Info logs a message at info level. + Info(...any) + // Infof logs a message at info level. + Infof(string, ...any) + // Infov logs a message at info level. + Infov(any) + // Infow logs a message at info level. + Infow(string, ...LogField) + // Slow logs a message at slow level. + Slow(...any) + // Slowf logs a message at slow level. + Slowf(string, ...any) + // Slowv logs a message at slow level. + Slowv(any) + // Sloww logs a message at slow level. + Sloww(string, ...LogField) + // WithCallerSkip returns a new logger with the given caller skip. + WithCallerSkip(skip int) Logger + // WithContext returns a new logger with the given context. + WithContext(ctx context.Context) Logger + // WithDuration returns a new logger with the given duration. + WithDuration(d time.Duration) Logger + // WithFields returns a new logger with the given fields. + WithFields(fields ...LogField) Logger +} diff --git a/nlog/logs.go b/nlog/logs.go new file mode 100644 index 0000000..02609e3 --- /dev/null +++ b/nlog/logs.go @@ -0,0 +1,464 @@ +package nlog + +import ( + "fmt" + "git.noahlan.cn/noahlan/ntool/nsys" + "io" + "log" + "os" + "path" + "runtime/debug" + "sync" + "sync/atomic" + "time" +) + +const CallerDepth = 4 + +var ( + timeFormat = "2006-01-02T15:04:05.000Z07:00" + logLevel uint32 = DebugLevel + encoding uint32 = plainEncodingType + // maxContentLength is used to truncate the log content, 0 for not truncating. + maxContentLength uint32 + // use uint32 for atomic operations + disableLog uint32 + disableStat uint32 + options logOptions + writer = new(atomicWriter) + setupOnce sync.Once +) + +type ( + // LogField is a key-value pair that will be added to the log entry. + LogField struct { + Key string + Value any + } + + // LogOption defines the method to customize the logging. + LogOption func(options *logOptions) + + logEntry map[string]any + + logOptions struct { + gzipEnabled bool + logStackCooldownMills int + keepDays int + maxBackups int + maxSize int + rotationRule string + } +) + +// Field returns a LogField for the given key and value. +func Field(key string, value any) LogField { + switch val := value.(type) { + case error: + return LogField{Key: key, Value: val.Error()} + case []error: + var errs []string + for _, err := range val { + errs = append(errs, err.Error()) + } + return LogField{Key: key, Value: errs} + case time.Duration: + return LogField{Key: key, Value: fmt.Sprint(val)} + case []time.Duration: + var durs []string + for _, dur := range val { + durs = append(durs, fmt.Sprint(dur)) + } + return LogField{Key: key, Value: durs} + case []time.Time: + var times []string + for _, t := range val { + times = append(times, fmt.Sprint(t)) + } + return LogField{Key: key, Value: times} + case fmt.Stringer: + return LogField{Key: key, Value: val.String()} + case []fmt.Stringer: + var strs []string + for _, str := range val { + strs = append(strs, str.String()) + } + return LogField{Key: key, Value: strs} + default: + return LogField{Key: key, Value: val} + } +} + +// Alert alerts v in alert level, and the message is written to error log. +func Alert(v string) { + GetWriter().Alert(v) +} + +// Debug writes v into access log. +func Debug(v ...any) { + writeDebug(fmt.Sprint(v...)) +} + +// Debugf writes v with format into access log. +func Debugf(format string, v ...any) { + writeDebug(fmt.Sprintf(format, v...)) +} + +// Debugv writes v into access log with json content. +func Debugv(v any) { + writeDebug(v) +} + +// Debugw writes msg along with fields into access log. +func Debugw(msg string, fields ...LogField) { + writeDebug(msg, fields...) +} + +// Error writes v into error log. +func Error(v ...any) { + writeError(fmt.Sprint(v...)) +} + +// Errorf writes v with format into error log. +func Errorf(format string, v ...any) { + writeError(fmt.Errorf(format, v...).Error()) +} + +// ErrorStack writes v along with call stack into error log. +func ErrorStack(v ...any) { + // there is newline in stack string + writeStack(fmt.Sprint(v...)) +} + +// ErrorStackf writes v along with call stack in format into error log. +func ErrorStackf(format string, v ...any) { + // there is newline in stack string + writeStack(fmt.Sprintf(format, v...)) +} + +// Errorv writes v into error log with json content. +// No call stack attached, because not elegant to pack the messages. +func Errorv(v any) { + writeError(v) +} + +// Errorw writes msg along with fields into error log. +func Errorw(msg string, fields ...LogField) { + writeError(msg, fields...) +} + +// Must checks if err is nil, otherwise logs the error and exits. +func Must(err error) { + if err == nil { + return + } + + msg := err.Error() + log.Print(msg) + GetWriter().Severe(msg) + os.Exit(1) +} + +// Info writes v into access log. +func Info(v ...any) { + writeInfo(fmt.Sprint(v...)) +} + +// Infof writes v with format into access log. +func Infof(format string, v ...any) { + writeInfo(fmt.Sprintf(format, v...)) +} + +// Infov writes v into access log with json content. +func Infov(v any) { + writeInfo(v) +} + +// Infow writes msg along with fields into access log. +func Infow(msg string, fields ...LogField) { + writeInfo(msg, fields...) +} + +// Severe writes v into severe log. +func Severe(v ...any) { + writeSevere(fmt.Sprint(v...)) +} + +// Severef writes v with format into severe log. +func Severef(format string, v ...any) { + writeSevere(fmt.Sprintf(format, v...)) +} + +// Slow writes v into slow log. +func Slow(v ...any) { + writeSlow(fmt.Sprint(v...)) +} + +// Slowf writes v with format into slow log. +func Slowf(format string, v ...any) { + writeSlow(fmt.Sprintf(format, v...)) +} + +// Slowv writes v into slow log with json content. +func Slowv(v any) { + writeSlow(v) +} + +// Sloww writes msg along with fields into slow log. +func Sloww(msg string, fields ...LogField) { + writeSlow(msg, fields...) +} + +// Stat writes v into stat log. +func Stat(v ...any) { + writeStat(fmt.Sprint(v...)) +} + +// Statf writes v with format into stat log. +func Statf(format string, v ...any) { + writeStat(fmt.Sprintf(format, v...)) +} + +// WithCooldownMillis customizes logging on writing call stack interval. +func WithCooldownMillis(millis int) LogOption { + return func(opts *logOptions) { + opts.logStackCooldownMills = millis + } +} + +// WithKeepDays customizes logging to keep logs with days. +func WithKeepDays(days int) LogOption { + return func(opts *logOptions) { + opts.keepDays = days + } +} + +// WithGzip customizes logging to automatically gzip the log files. +func WithGzip() LogOption { + return func(opts *logOptions) { + opts.gzipEnabled = true + } +} + +// WithMaxBackups customizes how many log files backups will be kept. +func WithMaxBackups(count int) LogOption { + return func(opts *logOptions) { + opts.maxBackups = count + } +} + +// WithMaxSize customizes how much space the writing log file can take up. +func WithMaxSize(size int) LogOption { + return func(opts *logOptions) { + opts.maxSize = size + } +} + +// WithRotation customizes which log rotation rule to use. +func WithRotation(r string) LogOption { + return func(opts *logOptions) { + opts.rotationRule = r + } +} + +// MustSetup sets up logging with given config c. It exits on error. +func MustSetup(c LogConf) { + Must(SetUp(c)) +} + +// Reset clears the writer and resets the log level. +func Reset() Writer { + return writer.Swap(nil) +} + +// SetLevel sets the logging level. It can be used to suppress some logs. +func SetLevel(level uint32) { + atomic.StoreUint32(&logLevel, level) +} + +// SetWriter sets the logging writer. It can be used to customize the logging. +func SetWriter(w Writer) { + if atomic.LoadUint32(&disableLog) == 0 { + writer.Store(w) + } +} + +// SetUp sets up the logx. If already set up, just return nil. +// we allow SetUp to be called multiple times, because for example +// we need to allow different service frameworks to initialize logx respectively. +func SetUp(c LogConf) (err error) { + // Just ignore the subsequent SetUp calls. + // Because multiple services in one process might call SetUp respectively. + // Need to wait for the first caller to complete the execution. + setupOnce.Do(func() { + setupLogLevel(c) + + if !c.Stat { + DisableStat() + } + + if len(c.TimeFormat) > 0 { + timeFormat = c.TimeFormat + } + + atomic.StoreUint32(&maxContentLength, c.MaxContentLength) + + switch c.Encoding { + case plainEncoding: + atomic.StoreUint32(&encoding, plainEncodingType) + default: + atomic.StoreUint32(&encoding, jsonEncodingType) + } + + switch c.Mode { + case fileMode: + err = setupWithFiles(c) + case volumeMode: + err = setupWithVolume(c) + default: + setupWithConsole() + } + }) + + return +} + +// Close closes the logging. +func Close() error { + if w := writer.Swap(nil); w != nil { + return w.(io.Closer).Close() + } + + return nil +} + +// Disable disables the logging. +func Disable() { + atomic.StoreUint32(&disableLog, 1) + writer.Store(nopWriter{}) +} + +// DisableStat disables the stat logs. +func DisableStat() { + atomic.StoreUint32(&disableStat, 1) +} + +func addCaller(fields ...LogField) []LogField { + return append(fields, Field(callerKey, getCaller(CallerDepth))) +} + +func createOutput(path string) (io.WriteCloser, error) { + if len(path) == 0 { + return nil, ErrLogPathNotSet + } + + switch options.rotationRule { + case sizeRotationRule: + return NewLogger(path, NewSizeLimitRotateRule(path, backupFileDelimiter, options.keepDays, + options.maxSize, options.maxBackups, options.gzipEnabled), options.gzipEnabled) + default: + return NewLogger(path, DefaultRotateRule(path, backupFileDelimiter, options.keepDays, + options.gzipEnabled), options.gzipEnabled) + } +} + +func GetWriter() Writer { + w := writer.Load() + if w == nil { + w = writer.StoreIfNil(newConsoleWriter()) + } + + return w +} + +func handleOptions(opts []LogOption) { + for _, opt := range opts { + opt(&options) + } +} + +func setupLogLevel(c LogConf) { + switch c.Level { + case levelDebug: + SetLevel(DebugLevel) + case levelInfo: + SetLevel(InfoLevel) + case levelError: + SetLevel(ErrorLevel) + case levelSevere: + SetLevel(SevereLevel) + } +} + +func setupWithConsole() { + SetWriter(newConsoleWriter()) +} + +func setupWithFiles(c LogConf) error { + w, err := newFileWriter(c) + if err != nil { + return err + } + + SetWriter(w) + return nil +} + +func setupWithVolume(c LogConf) error { + if len(c.ServiceName) == 0 { + return ErrLogServiceNameNotSet + } + + c.Path = path.Join(c.Path, c.ServiceName, nsys.Hostname()) + return setupWithFiles(c) +} + +func shallLog(level uint32) bool { + return atomic.LoadUint32(&logLevel) <= level +} + +func shallLogStat() bool { + return atomic.LoadUint32(&disableStat) == 0 +} + +func writeDebug(val any, fields ...LogField) { + if shallLog(DebugLevel) { + GetWriter().Debug(val, addCaller(fields...)...) + } +} + +func writeError(val any, fields ...LogField) { + if shallLog(ErrorLevel) { + GetWriter().Error(val, addCaller(fields...)...) + } +} + +func writeInfo(val any, fields ...LogField) { + if shallLog(InfoLevel) { + GetWriter().Info(val, addCaller(fields...)...) + } +} + +func writeSevere(msg string) { + if shallLog(SevereLevel) { + GetWriter().Severe(fmt.Sprintf("%s\n%s", msg, string(debug.Stack()))) + } +} + +func writeSlow(val any, fields ...LogField) { + if shallLog(ErrorLevel) { + GetWriter().Slow(val, addCaller(fields...)...) + } +} + +func writeStack(msg string) { + if shallLog(ErrorLevel) { + GetWriter().Stack(fmt.Sprintf("%s\n%s", msg, string(debug.Stack()))) + } +} + +func writeStat(msg string) { + if shallLogStat() && shallLog(InfoLevel) { + GetWriter().Stat(msg, addCaller()...) + } +} diff --git a/nlog/logs_test.go b/nlog/logs_test.go new file mode 100644 index 0000000..c965715 --- /dev/null +++ b/nlog/logs_test.go @@ -0,0 +1,838 @@ +package nlog + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "log" + "os" + "reflect" + "runtime" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "git.noahlan.cn/noahlan/ntool/ntest/assert" +) + +var ( + s = []byte("Sending #11 notification (id: 1451875113812010473) in #1 connection") + pool = make(chan []byte, 1) + _ Writer = (*mockWriter)(nil) +) + +type mockWriter struct { + lock sync.Mutex + builder strings.Builder +} + +func (mw *mockWriter) Alert(v any) { + mw.lock.Lock() + defer mw.lock.Unlock() + output(&mw.builder, levelAlert, v) +} + +func (mw *mockWriter) Debug(v any, fields ...LogField) { + mw.lock.Lock() + defer mw.lock.Unlock() + output(&mw.builder, levelDebug, v, fields...) +} + +func (mw *mockWriter) Error(v any, fields ...LogField) { + mw.lock.Lock() + defer mw.lock.Unlock() + output(&mw.builder, levelError, v, fields...) +} + +func (mw *mockWriter) Info(v any, fields ...LogField) { + mw.lock.Lock() + defer mw.lock.Unlock() + output(&mw.builder, levelInfo, v, fields...) +} + +func (mw *mockWriter) Severe(v any) { + mw.lock.Lock() + defer mw.lock.Unlock() + output(&mw.builder, levelSevere, v) +} + +func (mw *mockWriter) Slow(v any, fields ...LogField) { + mw.lock.Lock() + defer mw.lock.Unlock() + output(&mw.builder, levelSlow, v, fields...) +} + +func (mw *mockWriter) Stack(v any) { + mw.lock.Lock() + defer mw.lock.Unlock() + output(&mw.builder, levelError, v) +} + +func (mw *mockWriter) Stat(v any, fields ...LogField) { + mw.lock.Lock() + defer mw.lock.Unlock() + output(&mw.builder, levelStat, v, fields...) +} + +func (mw *mockWriter) Close() error { + return nil +} + +func (mw *mockWriter) Contains(text string) bool { + mw.lock.Lock() + defer mw.lock.Unlock() + return strings.Contains(mw.builder.String(), text) +} + +func (mw *mockWriter) Reset() { + mw.lock.Lock() + defer mw.lock.Unlock() + mw.builder.Reset() +} + +func (mw *mockWriter) String() string { + mw.lock.Lock() + defer mw.lock.Unlock() + return mw.builder.String() +} + +func TestField(t *testing.T) { + tests := []struct { + name string + f LogField + want map[string]any + }{ + { + name: "error", + f: Field("foo", errors.New("bar")), + want: map[string]any{ + "foo": "bar", + }, + }, + { + name: "errors", + f: Field("foo", []error{errors.New("bar"), errors.New("baz")}), + want: map[string]any{ + "foo": []any{"bar", "baz"}, + }, + }, + { + name: "strings", + f: Field("foo", []string{"bar", "baz"}), + want: map[string]any{ + "foo": []any{"bar", "baz"}, + }, + }, + { + name: "duration", + f: Field("foo", time.Second), + want: map[string]any{ + "foo": "1s", + }, + }, + { + name: "durations", + f: Field("foo", []time.Duration{time.Second, 2 * time.Second}), + want: map[string]any{ + "foo": []any{"1s", "2s"}, + }, + }, + { + name: "times", + f: Field("foo", []time.Time{ + time.Date(2020, time.January, 1, 0, 0, 0, 0, time.UTC), + time.Date(2020, time.January, 2, 0, 0, 0, 0, time.UTC), + }), + want: map[string]any{ + "foo": []any{"2020-01-01 00:00:00 +0000 UTC", "2020-01-02 00:00:00 +0000 UTC"}, + }, + }, + { + name: "stringer", + f: Field("foo", ValStringer{val: "bar"}), + want: map[string]any{ + "foo": "bar", + }, + }, + { + name: "stringers", + f: Field("foo", []fmt.Stringer{ValStringer{val: "bar"}, ValStringer{val: "baz"}}), + want: map[string]any{ + "foo": []any{"bar", "baz"}, + }, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + Infow("foo", test.f) + validateFields(t, w.String(), test.want) + }) + } +} + +func TestFileLineFileMode(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + file, line := getFileLine() + Error("anything") + assert.True(t, w.Contains(fmt.Sprintf("%s:%d", file, line+1))) + + file, line = getFileLine() + Errorf("anything %s", "format") + assert.True(t, w.Contains(fmt.Sprintf("%s:%d", file, line+1))) +} + +func TestFileLineConsoleMode(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + file, line := getFileLine() + Error("anything") + assert.True(t, w.Contains(fmt.Sprintf("%s:%d", file, line+1))) + + w.Reset() + file, line = getFileLine() + Errorf("anything %s", "format") + assert.True(t, w.Contains(fmt.Sprintf("%s:%d", file, line+1))) +} + +func TestStructedLogAlert(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLog(t, levelAlert, w, func(v ...any) { + Alert(fmt.Sprint(v...)) + }) +} + +func TestStructedLogDebug(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLog(t, levelDebug, w, func(v ...any) { + Debug(v...) + }) +} + +func TestStructedLogDebugf(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLog(t, levelDebug, w, func(v ...any) { + Debugf(fmt.Sprint(v...)) + }) +} + +func TestStructedLogDebugv(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLog(t, levelDebug, w, func(v ...any) { + Debugv(fmt.Sprint(v...)) + }) +} + +func TestStructedLogDebugw(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLog(t, levelDebug, w, func(v ...any) { + Debugw(fmt.Sprint(v...), Field("foo", time.Second)) + }) +} + +func TestStructedLogError(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLog(t, levelError, w, func(v ...any) { + Error(v...) + }) +} + +func TestStructedLogErrorf(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLog(t, levelError, w, func(v ...any) { + Errorf("%s", fmt.Sprint(v...)) + }) +} + +func TestStructedLogErrorv(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLog(t, levelError, w, func(v ...any) { + Errorv(fmt.Sprint(v...)) + }) +} + +func TestStructedLogErrorw(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLog(t, levelError, w, func(v ...any) { + Errorw(fmt.Sprint(v...), Field("foo", "bar")) + }) +} + +func TestStructedLogInfo(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLog(t, levelInfo, w, func(v ...any) { + Info(v...) + }) +} + +func TestStructedLogInfof(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLog(t, levelInfo, w, func(v ...any) { + Infof("%s", fmt.Sprint(v...)) + }) +} + +func TestStructedLogInfov(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLog(t, levelInfo, w, func(v ...any) { + Infov(fmt.Sprint(v...)) + }) +} + +func TestStructedLogInfow(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLog(t, levelInfo, w, func(v ...any) { + Infow(fmt.Sprint(v...), Field("foo", "bar")) + }) +} + +func TestStructedLogInfoConsoleAny(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLogConsole(t, w, func(v ...any) { + old := atomic.LoadUint32(&encoding) + atomic.StoreUint32(&encoding, plainEncodingType) + defer func() { + atomic.StoreUint32(&encoding, old) + }() + + Infov(v) + }) +} + +func TestStructedLogInfoConsoleAnyString(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLogConsole(t, w, func(v ...any) { + old := atomic.LoadUint32(&encoding) + atomic.StoreUint32(&encoding, plainEncodingType) + defer func() { + atomic.StoreUint32(&encoding, old) + }() + + Infov(fmt.Sprint(v...)) + }) +} + +func TestStructedLogInfoConsoleAnyError(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLogConsole(t, w, func(v ...any) { + old := atomic.LoadUint32(&encoding) + atomic.StoreUint32(&encoding, plainEncodingType) + defer func() { + atomic.StoreUint32(&encoding, old) + }() + + Infov(errors.New(fmt.Sprint(v...))) + }) +} + +func TestStructedLogInfoConsoleAnyStringer(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLogConsole(t, w, func(v ...any) { + old := atomic.LoadUint32(&encoding) + atomic.StoreUint32(&encoding, plainEncodingType) + defer func() { + atomic.StoreUint32(&encoding, old) + }() + + Infov(ValStringer{ + val: fmt.Sprint(v...), + }) + }) +} + +func TestStructedLogInfoConsoleText(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLogConsole(t, w, func(v ...any) { + old := atomic.LoadUint32(&encoding) + atomic.StoreUint32(&encoding, plainEncodingType) + defer func() { + atomic.StoreUint32(&encoding, old) + }() + + Info(fmt.Sprint(v...)) + }) +} + +func TestStructedLogSlow(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLog(t, levelSlow, w, func(v ...any) { + Slow(v...) + }) +} + +func TestStructedLogSlowf(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLog(t, levelSlow, w, func(v ...any) { + Slowf(fmt.Sprint(v...)) + }) +} + +func TestStructedLogSlowv(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLog(t, levelSlow, w, func(v ...any) { + Slowv(fmt.Sprint(v...)) + }) +} + +func TestStructedLogSloww(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLog(t, levelSlow, w, func(v ...any) { + Sloww(fmt.Sprint(v...), Field("foo", time.Second)) + }) +} + +func TestStructedLogStat(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLog(t, levelStat, w, func(v ...any) { + Stat(v...) + }) +} + +func TestStructedLogStatf(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLog(t, levelStat, w, func(v ...any) { + Statf(fmt.Sprint(v...)) + }) +} + +func TestStructedLogSevere(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLog(t, levelSevere, w, func(v ...any) { + Severe(v...) + }) +} + +func TestStructedLogSeveref(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + doTestStructedLog(t, levelSevere, w, func(v ...any) { + Severef(fmt.Sprint(v...)) + }) +} + +func TestStructedLogWithDuration(t *testing.T) { + const message = "hello there" + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + WithDuration(time.Second).Info(message) + var entry map[string]any + if err := json.Unmarshal([]byte(w.String()), &entry); err != nil { + t.Error(err) + } + assert.Equal(t, levelInfo, entry[levelKey]) + assert.Equal(t, message, entry[contentKey]) + assert.Equal(t, "1000.0ms", entry[durationKey]) +} + +func TestSetLevel(t *testing.T) { + SetLevel(ErrorLevel) + const message = "hello there" + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + Info(message) + assert.Equal(t, 0, w.builder.Len()) +} + +func TestSetLevelTwiceWithMode(t *testing.T) { + testModes := []string{ + "console", + "volumn", + "mode", + } + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + for _, mode := range testModes { + testSetLevelTwiceWithMode(t, mode, w) + } +} + +func TestSetLevelWithDuration(t *testing.T) { + SetLevel(ErrorLevel) + const message = "hello there" + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + WithDuration(time.Second).Info(message) + assert.Equal(t, 0, w.builder.Len()) +} + +func TestErrorfWithWrappedError(t *testing.T) { + SetLevel(ErrorLevel) + const message = "there" + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + Errorf("hello %w", errors.New(message)) + assert.True(t, strings.Contains(w.String(), "hello there")) +} + +func TestMustNil(t *testing.T) { + Must(nil) +} + +func TestSetup(t *testing.T) { + defer func() { + SetLevel(InfoLevel) + atomic.StoreUint32(&encoding, jsonEncodingType) + }() + + MustSetup(LogConf{ + ServiceName: "any", + Mode: "console", + TimeFormat: timeFormat, + }) + MustSetup(LogConf{ + ServiceName: "any", + Mode: "file", + Path: os.TempDir(), + }) + MustSetup(LogConf{ + ServiceName: "any", + Mode: "volume", + Path: os.TempDir(), + }) + MustSetup(LogConf{ + ServiceName: "any", + Mode: "console", + TimeFormat: timeFormat, + }) + MustSetup(LogConf{ + ServiceName: "any", + Mode: "console", + Encoding: plainEncoding, + }) + + defer os.RemoveAll("CD01CB7D-2705-4F3F-889E-86219BF56F10") + assert.NotNil(t, setupWithVolume(LogConf{})) + assert.Nil(t, setupWithVolume(LogConf{ + ServiceName: "CD01CB7D-2705-4F3F-889E-86219BF56F10", + })) + assert.Nil(t, setupWithVolume(LogConf{ + ServiceName: "CD01CB7D-2705-4F3F-889E-86219BF56F10", + Rotation: sizeRotationRule, + })) + assert.NotNil(t, setupWithFiles(LogConf{})) + assert.Nil(t, setupWithFiles(LogConf{ + ServiceName: "any", + Path: os.TempDir(), + Compress: true, + KeepDays: 1, + MaxBackups: 3, + MaxSize: 1024 * 1024, + })) + setupLogLevel(LogConf{ + Level: levelInfo, + }) + setupLogLevel(LogConf{ + Level: levelError, + }) + setupLogLevel(LogConf{ + Level: levelSevere, + }) + _, err := createOutput("") + assert.NotNil(t, err) + Disable() + SetLevel(InfoLevel) + atomic.StoreUint32(&encoding, jsonEncodingType) +} + +func TestDisable(t *testing.T) { + Disable() + + var opt logOptions + WithKeepDays(1)(&opt) + WithGzip()(&opt) + WithMaxBackups(1)(&opt) + WithMaxSize(1024)(&opt) + assert.Nil(t, Close()) + assert.Nil(t, Close()) +} + +func TestDisableStat(t *testing.T) { + DisableStat() + + const message = "hello there" + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + Stat(message) + assert.Equal(t, 0, w.builder.Len()) +} + +func TestSetWriter(t *testing.T) { + atomic.StoreUint32(&disableLog, 0) + Reset() + SetWriter(nopWriter{}) + assert.NotNil(t, writer.Load()) + assert.True(t, writer.Load() == nopWriter{}) + mocked := new(mockWriter) + SetWriter(mocked) + assert.Equal(t, mocked, writer.Load()) +} + +func TestWithGzip(t *testing.T) { + fn := WithGzip() + var opt logOptions + fn(&opt) + assert.True(t, opt.gzipEnabled) +} + +func TestWithKeepDays(t *testing.T) { + fn := WithKeepDays(1) + var opt logOptions + fn(&opt) + assert.Equal(t, 1, opt.keepDays) +} + +func BenchmarkCopyByteSliceAppend(b *testing.B) { + for i := 0; i < b.N; i++ { + var buf []byte + buf = append(buf, getTimestamp()...) + buf = append(buf, ' ') + buf = append(buf, s...) + _ = buf + } +} + +func BenchmarkCopyByteSliceAllocExactly(b *testing.B) { + for i := 0; i < b.N; i++ { + now := []byte(getTimestamp()) + buf := make([]byte, len(now)+1+len(s)) + n := copy(buf, now) + buf[n] = ' ' + copy(buf[n+1:], s) + } +} + +func BenchmarkCopyByteSlice(b *testing.B) { + var buf []byte + for i := 0; i < b.N; i++ { + buf = make([]byte, len(s)) + copy(buf, s) + } + fmt.Fprint(io.Discard, buf) +} + +func BenchmarkCopyOnWriteByteSlice(b *testing.B) { + var buf []byte + for i := 0; i < b.N; i++ { + size := len(s) + buf = s[:size:size] + } + fmt.Fprint(io.Discard, buf) +} + +func BenchmarkCacheByteSlice(b *testing.B) { + for i := 0; i < b.N; i++ { + dup := fetch() + copy(dup, s) + put(dup) + } +} + +func BenchmarkLogs(b *testing.B) { + b.ReportAllocs() + + log.SetOutput(io.Discard) + for i := 0; i < b.N; i++ { + Info(i) + } +} + +func fetch() []byte { + select { + case b := <-pool: + return b + default: + } + return make([]byte, 4096) +} + +func getFileLine() (string, int) { + _, file, line, _ := runtime.Caller(1) + short := file + + for i := len(file) - 1; i > 0; i-- { + if file[i] == '/' { + short = file[i+1:] + break + } + } + + return short, line +} + +func put(b []byte) { + select { + case pool <- b: + default: + } +} + +func doTestStructedLog(t *testing.T, level string, w *mockWriter, write func(...any)) { + const message = "hello there" + write(message) + + var entry map[string]any + if err := json.Unmarshal([]byte(w.String()), &entry); err != nil { + t.Error(err) + } + + assert.Equal(t, level, entry[levelKey]) + val, ok := entry[contentKey] + assert.True(t, ok) + assert.True(t, strings.Contains(val.(string), message)) +} + +func doTestStructedLogConsole(t *testing.T, w *mockWriter, write func(...any)) { + const message = "hello there" + write(message) + assert.True(t, strings.Contains(w.String(), message)) +} + +func testSetLevelTwiceWithMode(t *testing.T, mode string, w *mockWriter) { + writer.Store(nil) + SetUp(LogConf{ + Mode: mode, + Level: "debug", + Path: "/dev/null", + Encoding: plainEncoding, + Stat: false, + TimeFormat: time.RFC3339, + }) + SetUp(LogConf{ + Mode: mode, + Level: "info", + Path: "/dev/null", + }) + const message = "hello there" + Info(message) + assert.Equal(t, 0, w.builder.Len()) + Infof(message) + assert.Equal(t, 0, w.builder.Len()) + ErrorStack(message) + assert.Equal(t, 0, w.builder.Len()) + ErrorStackf(message) + assert.Equal(t, 0, w.builder.Len()) +} + +type ValStringer struct { + val string +} + +func (v ValStringer) String() string { + return v.val +} + +func validateFields(t *testing.T, content string, fields map[string]any) { + var m map[string]any + if err := json.Unmarshal([]byte(content), &m); err != nil { + t.Error(err) + } + + for k, v := range fields { + if reflect.TypeOf(v).Kind() == reflect.Slice { + assert.EqualValues(t, v, m[k]) + } else { + assert.Equal(t, v, m[k], content) + } + } +} diff --git a/nlog/logwriter.go b/nlog/logwriter.go new file mode 100644 index 0000000..1de1a88 --- /dev/null +++ b/nlog/logwriter.go @@ -0,0 +1,22 @@ +package nlog + +import "log" + +type logWriter struct { + logger *log.Logger +} + +func newLogWriter(logger *log.Logger) logWriter { + return logWriter{ + logger: logger, + } +} + +func (lw logWriter) Close() error { + return nil +} + +func (lw logWriter) Write(data []byte) (int, error) { + lw.logger.Print(string(data)) + return len(data), nil +} diff --git a/nlog/richlogger.go b/nlog/richlogger.go new file mode 100644 index 0000000..021892a --- /dev/null +++ b/nlog/richlogger.go @@ -0,0 +1,179 @@ +package nlog + +import ( + "context" + "fmt" + "git.noahlan.cn/noahlan/ntool/ntime" + "time" +) + +// WithCallerSkip returns a Logger with given caller skip. +func WithCallerSkip(skip int) Logger { + if skip <= 0 { + return new(richLogger) + } + + return &richLogger{ + callerSkip: skip, + } +} + +// WithContext sets ctx to log, for keeping tracing information. +func WithContext(ctx context.Context) Logger { + return &richLogger{ + ctx: ctx, + } +} + +// WithDuration returns a Logger with given duration. +func WithDuration(d time.Duration) Logger { + return &richLogger{ + fields: []LogField{Field(durationKey, ntime.ReprOfDuration(d))}, + } +} + +type richLogger struct { + ctx context.Context + callerSkip int + fields []LogField +} + +func (l *richLogger) Debug(v ...any) { + l.debug(fmt.Sprint(v...)) +} + +func (l *richLogger) Debugf(format string, v ...any) { + l.debug(fmt.Sprintf(format, v...)) +} + +func (l *richLogger) Debugv(v any) { + l.debug(v) +} + +func (l *richLogger) Debugw(msg string, fields ...LogField) { + l.debug(msg, fields...) +} + +func (l *richLogger) Error(v ...any) { + l.err(fmt.Sprint(v...)) +} + +func (l *richLogger) Errorf(format string, v ...any) { + l.err(fmt.Sprintf(format, v...)) +} + +func (l *richLogger) Errorv(v any) { + l.err(fmt.Sprint(v)) +} + +func (l *richLogger) Errorw(msg string, fields ...LogField) { + l.err(msg, fields...) +} + +func (l *richLogger) Info(v ...any) { + l.info(fmt.Sprint(v...)) +} + +func (l *richLogger) Infof(format string, v ...any) { + l.info(fmt.Sprintf(format, v...)) +} + +func (l *richLogger) Infov(v any) { + l.info(v) +} + +func (l *richLogger) Infow(msg string, fields ...LogField) { + l.info(msg, fields...) +} + +func (l *richLogger) Slow(v ...any) { + l.slow(fmt.Sprint(v...)) +} + +func (l *richLogger) Slowf(format string, v ...any) { + l.slow(fmt.Sprintf(format, v...)) +} + +func (l *richLogger) Slowv(v any) { + l.slow(v) +} + +func (l *richLogger) Sloww(msg string, fields ...LogField) { + l.slow(msg, fields...) +} + +func (l *richLogger) WithCallerSkip(skip int) Logger { + if skip <= 0 { + return l + } + + l.callerSkip = skip + return l +} + +func (l *richLogger) WithContext(ctx context.Context) Logger { + l.ctx = ctx + return l +} + +func (l *richLogger) WithDuration(duration time.Duration) Logger { + l.fields = append(l.fields, Field(durationKey, ntime.ReprOfDuration(duration))) + return l +} + +func (l *richLogger) WithFields(fields ...LogField) Logger { + l.fields = append(l.fields, fields...) + return l +} + +func (l *richLogger) buildFields(fields ...LogField) []LogField { + fields = append(l.fields, fields...) + fields = append(fields, Field(callerKey, getCaller(CallerDepth+l.callerSkip))) + + if l.ctx == nil { + return fields + } + + traceID := traceIDFromContext(l.ctx) + if len(traceID) > 0 { + fields = append(fields, Field(traceKey, traceID)) + } + + spanID := spanIDFromContext(l.ctx) + if len(spanID) > 0 { + fields = append(fields, Field(spanKey, spanID)) + } + + val := l.ctx.Value(fieldsContextKey) + if val != nil { + if arr, ok := val.([]LogField); ok { + fields = append(fields, arr...) + } + } + + return fields +} + +func (l *richLogger) debug(v any, fields ...LogField) { + if shallLog(DebugLevel) { + GetWriter().Debug(v, l.buildFields(fields...)...) + } +} + +func (l *richLogger) err(v any, fields ...LogField) { + if shallLog(ErrorLevel) { + GetWriter().Error(v, l.buildFields(fields...)...) + } +} + +func (l *richLogger) info(v any, fields ...LogField) { + if shallLog(InfoLevel) { + GetWriter().Info(v, l.buildFields(fields...)...) + } +} + +func (l *richLogger) slow(v any, fields ...LogField) { + if shallLog(ErrorLevel) { + GetWriter().Slow(v, l.buildFields(fields...)...) + } +} diff --git a/nlog/richlogger_test.go b/nlog/richlogger_test.go new file mode 100644 index 0000000..87ac2ea --- /dev/null +++ b/nlog/richlogger_test.go @@ -0,0 +1,318 @@ +package nlog + +import ( + "context" + "encoding/json" + "fmt" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "io" + "strings" + "sync/atomic" + "testing" + "time" + + "go.opentelemetry.io/otel" + sdktrace "go.opentelemetry.io/otel/sdk/trace" +) + +func TestTraceLog(t *testing.T) { + SetLevel(InfoLevel) + w := new(mockWriter) + old := writer.Swap(w) + writer.lock.RLock() + defer func() { + writer.lock.RUnlock() + writer.Store(old) + }() + + otp := otel.GetTracerProvider() + tp := sdktrace.NewTracerProvider(sdktrace.WithSampler(sdktrace.AlwaysSample())) + otel.SetTracerProvider(tp) + defer otel.SetTracerProvider(otp) + + ctx, span := tp.Tracer("trace-id").Start(context.Background(), "span-id") + defer span.End() + + WithContext(ctx).Info(testlog) + validate(t, w.String(), true, true) +} + +func TestTraceDebug(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + writer.lock.RLock() + defer func() { + writer.lock.RUnlock() + writer.Store(old) + }() + + otp := otel.GetTracerProvider() + tp := sdktrace.NewTracerProvider(sdktrace.WithSampler(sdktrace.AlwaysSample())) + otel.SetTracerProvider(tp) + defer otel.SetTracerProvider(otp) + + ctx, span := tp.Tracer("foo").Start(context.Background(), "bar") + defer span.End() + + l := WithContext(ctx) + SetLevel(DebugLevel) + l.WithDuration(time.Second).Debug(testlog) + assert.True(t, strings.Contains(w.String(), traceKey)) + assert.True(t, strings.Contains(w.String(), spanKey)) + w.Reset() + l.WithDuration(time.Second).Debugf(testlog) + validate(t, w.String(), true, true) + w.Reset() + l.WithDuration(time.Second).Debugv(testlog) + validate(t, w.String(), true, true) + w.Reset() + l.WithDuration(time.Second).Debugw(testlog, Field("foo", "bar")) + validate(t, w.String(), true, true) + assert.True(t, strings.Contains(w.String(), "foo"), w.String()) + assert.True(t, strings.Contains(w.String(), "bar"), w.String()) +} + +func TestTraceError(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + writer.lock.RLock() + defer func() { + writer.lock.RUnlock() + writer.Store(old) + }() + + otp := otel.GetTracerProvider() + tp := sdktrace.NewTracerProvider(sdktrace.WithSampler(sdktrace.AlwaysSample())) + otel.SetTracerProvider(tp) + defer otel.SetTracerProvider(otp) + + ctx, span := tp.Tracer("trace-id").Start(context.Background(), "span-id") + defer span.End() + + var nilCtx context.Context + l := WithContext(context.Background()) + l = l.WithContext(nilCtx) + l = l.WithContext(ctx) + SetLevel(ErrorLevel) + l.WithDuration(time.Second).Error(testlog) + validate(t, w.String(), true, true) + w.Reset() + l.WithDuration(time.Second).Errorf(testlog) + validate(t, w.String(), true, true) + w.Reset() + l.WithDuration(time.Second).Errorv(testlog) + validate(t, w.String(), true, true) + w.Reset() + l.WithDuration(time.Second).Errorw(testlog, Field("basket", "ball")) + validate(t, w.String(), true, true) + assert.True(t, strings.Contains(w.String(), "basket"), w.String()) + assert.True(t, strings.Contains(w.String(), "ball"), w.String()) +} + +func TestTraceInfo(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + writer.lock.RLock() + defer func() { + writer.lock.RUnlock() + writer.Store(old) + }() + + otp := otel.GetTracerProvider() + tp := sdktrace.NewTracerProvider(sdktrace.WithSampler(sdktrace.AlwaysSample())) + otel.SetTracerProvider(tp) + defer otel.SetTracerProvider(otp) + + ctx, span := tp.Tracer("trace-id").Start(context.Background(), "span-id") + defer span.End() + + SetLevel(InfoLevel) + l := WithContext(ctx) + l.WithDuration(time.Second).Info(testlog) + validate(t, w.String(), true, true) + w.Reset() + l.WithDuration(time.Second).Infof(testlog) + validate(t, w.String(), true, true) + w.Reset() + l.WithDuration(time.Second).Infov(testlog) + validate(t, w.String(), true, true) + w.Reset() + l.WithDuration(time.Second).Infow(testlog, Field("basket", "ball")) + validate(t, w.String(), true, true) + assert.True(t, strings.Contains(w.String(), "basket"), w.String()) + assert.True(t, strings.Contains(w.String(), "ball"), w.String()) +} + +func TestTraceInfoConsole(t *testing.T) { + old := atomic.SwapUint32(&encoding, jsonEncodingType) + defer atomic.StoreUint32(&encoding, old) + + w := new(mockWriter) + o := writer.Swap(w) + writer.lock.RLock() + defer func() { + writer.lock.RUnlock() + writer.Store(o) + }() + + otp := otel.GetTracerProvider() + tp := sdktrace.NewTracerProvider(sdktrace.WithSampler(sdktrace.AlwaysSample())) + otel.SetTracerProvider(tp) + defer otel.SetTracerProvider(otp) + + ctx, span := tp.Tracer("trace-id").Start(context.Background(), "span-id") + defer span.End() + + l := WithContext(ctx) + SetLevel(InfoLevel) + l.WithDuration(time.Second).Info(testlog) + validate(t, w.String(), true, true) + w.Reset() + l.WithDuration(time.Second).Infof(testlog) + validate(t, w.String(), true, true) + w.Reset() + l.WithDuration(time.Second).Infov(testlog) + validate(t, w.String(), true, true) +} + +func TestTraceSlow(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + writer.lock.RLock() + defer func() { + writer.lock.RUnlock() + writer.Store(old) + }() + + otp := otel.GetTracerProvider() + tp := sdktrace.NewTracerProvider(sdktrace.WithSampler(sdktrace.AlwaysSample())) + otel.SetTracerProvider(tp) + defer otel.SetTracerProvider(otp) + + ctx, span := tp.Tracer("trace-id").Start(context.Background(), "span-id") + defer span.End() + + l := WithContext(ctx) + SetLevel(InfoLevel) + l.WithDuration(time.Second).Slow(testlog) + assert.True(t, strings.Contains(w.String(), traceKey)) + assert.True(t, strings.Contains(w.String(), spanKey)) + w.Reset() + l.WithDuration(time.Second).Slowf(testlog) + validate(t, w.String(), true, true) + w.Reset() + l.WithDuration(time.Second).Slowv(testlog) + validate(t, w.String(), true, true) + w.Reset() + l.WithDuration(time.Second).Sloww(testlog, Field("basket", "ball")) + validate(t, w.String(), true, true) + assert.True(t, strings.Contains(w.String(), "basket"), w.String()) + assert.True(t, strings.Contains(w.String(), "ball"), w.String()) +} + +func TestTraceWithoutContext(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + writer.lock.RLock() + defer func() { + writer.lock.RUnlock() + writer.Store(old) + }() + + l := WithContext(context.Background()) + SetLevel(InfoLevel) + l.WithDuration(time.Second).Info(testlog) + validate(t, w.String(), false, false) + w.Reset() + l.WithDuration(time.Second).Infof(testlog) + validate(t, w.String(), false, false) +} + +func TestLogWithFields(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + writer.lock.RLock() + defer func() { + writer.lock.RUnlock() + writer.Store(old) + }() + + ctx := ContextWithFields(context.Background(), Field("foo", "bar")) + l := WithContext(ctx) + SetLevel(InfoLevel) + l.Info(testlog) + + var val mockValue + assert.Nil(t, json.Unmarshal([]byte(w.String()), &val)) + assert.Equal(t, "bar", val.Foo) +} + +func TestLogWithCallerSkip(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + writer.lock.RLock() + defer func() { + writer.lock.RUnlock() + writer.Store(old) + }() + + l := WithCallerSkip(1).WithCallerSkip(0) + p := func(v string) { + l.Info(v) + } + + file, line := getFileLine() + p(testlog) + assert.True(t, w.Contains(fmt.Sprintf("%s:%d", file, line+1))) + + w.Reset() + l = WithCallerSkip(0).WithCallerSkip(1) + file, line = getFileLine() + p(testlog) + assert.True(t, w.Contains(fmt.Sprintf("%s:%d", file, line+1))) +} + +func TestLoggerWithFields(t *testing.T) { + w := new(mockWriter) + old := writer.Swap(w) + writer.lock.RLock() + defer func() { + writer.lock.RUnlock() + writer.Store(old) + }() + + l := WithContext(context.Background()).WithFields(Field("foo", "bar")) + l.Info(testlog) + + var val mockValue + assert.Nil(t, json.Unmarshal([]byte(w.String()), &val)) + assert.Equal(t, "bar", val.Foo) +} + +func validate(t *testing.T, body string, expectedTrace, expectedSpan bool) { + var val mockValue + dec := json.NewDecoder(strings.NewReader(body)) + + for { + var doc mockValue + err := dec.Decode(&doc) + if err == io.EOF { + // all done + break + } + if err != nil { + continue + } + + val = doc + } + + assert.Equal(t, expectedTrace, len(val.Trace) > 0, body) + assert.Equal(t, expectedSpan, len(val.Span) > 0, body) +} + +type mockValue struct { + Trace string `json:"trace"` + Span string `json:"span"` + Foo string `json:"foo"` +} diff --git a/nlog/rotatelogger.go b/nlog/rotatelogger.go new file mode 100644 index 0000000..f5627fe --- /dev/null +++ b/nlog/rotatelogger.go @@ -0,0 +1,443 @@ +package nlog + +import ( + "compress/gzip" + "errors" + "fmt" + "git.noahlan.cn/noahlan/ntool/nfs" + "io" + "log" + "os" + "path" + "path/filepath" + "sort" + "strings" + "sync" + "time" +) + +const ( + dateFormat = "2006-01-02" + fileTimeFormat = time.RFC3339 + hoursPerDay = 24 + bufferSize = 100 + defaultDirMode = 0o755 + defaultFileMode = 0o600 + gzipExt = ".gz" + megaBytes = 1 << 20 +) + +// ErrLogFileClosed is an error that indicates the log file is already closed. +var ErrLogFileClosed = errors.New("error: log file closed") + +type ( + // A RotateRule interface is used to define the log rotating rules. + RotateRule interface { + BackupFileName() string + MarkRotated() + OutdatedFiles() []string + ShallRotate(size int64) bool + } + + // A RotateLogger is a Logger that can rotate log files with given rules. + RotateLogger struct { + filename string + backup string + fp *os.File + channel chan []byte + done chan struct{} + rule RotateRule + compress bool + // can't use threading.RoutineGroup because of cycle import + waitGroup sync.WaitGroup + closeOnce sync.Once + currentSize int64 + } + + // A DailyRotateRule is a rule to daily rotate the log files. + DailyRotateRule struct { + rotatedTime string + filename string + delimiter string + days int + gzip bool + } + + // SizeLimitRotateRule a rotation rule that make the log file rotated base on size + SizeLimitRotateRule struct { + DailyRotateRule + maxSize int64 + maxBackups int + } +) + +// DefaultRotateRule is a default log rotating rule, currently DailyRotateRule. +func DefaultRotateRule(filename, delimiter string, days int, gzip bool) RotateRule { + return &DailyRotateRule{ + rotatedTime: getNowDate(), + filename: filename, + delimiter: delimiter, + days: days, + gzip: gzip, + } +} + +// BackupFileName returns the backup filename on rotating. +func (r *DailyRotateRule) BackupFileName() string { + return fmt.Sprintf("%s%s%s", r.filename, r.delimiter, getNowDate()) +} + +// MarkRotated marks the rotated time of r to be the current time. +func (r *DailyRotateRule) MarkRotated() { + r.rotatedTime = getNowDate() +} + +// OutdatedFiles returns the files that exceeded the keeping days. +func (r *DailyRotateRule) OutdatedFiles() []string { + if r.days <= 0 { + return nil + } + + var pattern string + if r.gzip { + pattern = fmt.Sprintf("%s%s*%s", r.filename, r.delimiter, gzipExt) + } else { + pattern = fmt.Sprintf("%s%s*", r.filename, r.delimiter) + } + + files, err := filepath.Glob(pattern) + if err != nil { + Errorf("failed to delete outdated log files, error: %s", err) + return nil + } + + var buf strings.Builder + boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay*r.days)).Format(dateFormat) + buf.WriteString(r.filename) + buf.WriteString(r.delimiter) + buf.WriteString(boundary) + if r.gzip { + buf.WriteString(gzipExt) + } + boundaryFile := buf.String() + + var outdates []string + for _, file := range files { + if file < boundaryFile { + outdates = append(outdates, file) + } + } + + return outdates +} + +// ShallRotate checks if the file should be rotated. +func (r *DailyRotateRule) ShallRotate(_ int64) bool { + return len(r.rotatedTime) > 0 && getNowDate() != r.rotatedTime +} + +// NewSizeLimitRotateRule returns the rotation rule with size limit +func NewSizeLimitRotateRule(filename, delimiter string, days, maxSize, maxBackups int, gzip bool) RotateRule { + return &SizeLimitRotateRule{ + DailyRotateRule: DailyRotateRule{ + rotatedTime: getNowDateInRFC3339Format(), + filename: filename, + delimiter: delimiter, + days: days, + gzip: gzip, + }, + maxSize: int64(maxSize) * megaBytes, + maxBackups: maxBackups, + } +} + +func (r *SizeLimitRotateRule) BackupFileName() string { + dir := filepath.Dir(r.filename) + prefix, ext := r.parseFilename() + timestamp := getNowDateInRFC3339Format() + return filepath.Join(dir, fmt.Sprintf("%s%s%s%s", prefix, r.delimiter, timestamp, ext)) +} + +func (r *SizeLimitRotateRule) MarkRotated() { + r.rotatedTime = getNowDateInRFC3339Format() +} + +func (r *SizeLimitRotateRule) OutdatedFiles() []string { + dir := filepath.Dir(r.filename) + prefix, ext := r.parseFilename() + + var pattern string + if r.gzip { + pattern = fmt.Sprintf("%s%s%s%s*%s%s", dir, string(filepath.Separator), + prefix, r.delimiter, ext, gzipExt) + } else { + pattern = fmt.Sprintf("%s%s%s%s*%s", dir, string(filepath.Separator), + prefix, r.delimiter, ext) + } + + files, err := filepath.Glob(pattern) + if err != nil { + Errorf("failed to delete outdated log files, error: %s", err) + return nil + } + + sort.Strings(files) + + outdated := make(map[string]struct{}) + + // test if too many backups + if r.maxBackups > 0 && len(files) > r.maxBackups { + for _, f := range files[:len(files)-r.maxBackups] { + outdated[f] = struct{}{} + } + files = files[len(files)-r.maxBackups:] + } + + // test if any too old backups + if r.days > 0 { + boundary := time.Now().Add(-time.Hour * time.Duration(hoursPerDay*r.days)).Format(fileTimeFormat) + boundaryFile := filepath.Join(dir, fmt.Sprintf("%s%s%s%s", prefix, r.delimiter, boundary, ext)) + if r.gzip { + boundaryFile += gzipExt + } + for _, f := range files { + if f >= boundaryFile { + break + } + outdated[f] = struct{}{} + } + } + + var result []string + for k := range outdated { + result = append(result, k) + } + return result +} + +func (r *SizeLimitRotateRule) ShallRotate(size int64) bool { + return r.maxSize > 0 && r.maxSize < size +} + +func (r *SizeLimitRotateRule) parseFilename() (prefix, ext string) { + logName := filepath.Base(r.filename) + ext = filepath.Ext(r.filename) + prefix = logName[:len(logName)-len(ext)] + return +} + +// NewLogger returns a RotateLogger with given filename and rule, etc. +func NewLogger(filename string, rule RotateRule, compress bool) (*RotateLogger, error) { + l := &RotateLogger{ + filename: filename, + channel: make(chan []byte, bufferSize), + done: make(chan struct{}), + rule: rule, + compress: compress, + } + if err := l.init(); err != nil { + return nil, err + } + + l.startWorker() + return l, nil +} + +// Close closes l. +func (l *RotateLogger) Close() error { + var err error + + l.closeOnce.Do(func() { + close(l.done) + l.waitGroup.Wait() + + if err = l.fp.Sync(); err != nil { + return + } + + err = l.fp.Close() + }) + + return err +} + +func (l *RotateLogger) Write(data []byte) (int, error) { + select { + case l.channel <- data: + return len(data), nil + case <-l.done: + log.Println(string(data)) + return 0, ErrLogFileClosed + } +} + +func (l *RotateLogger) getBackupFilename() string { + if len(l.backup) == 0 { + return l.rule.BackupFileName() + } + + return l.backup +} + +func (l *RotateLogger) init() error { + l.backup = l.rule.BackupFileName() + + if fileInfo, err := os.Stat(l.filename); err != nil { + basePath := path.Dir(l.filename) + if _, err = os.Stat(basePath); err != nil { + if err = os.MkdirAll(basePath, defaultDirMode); err != nil { + return err + } + } + + if l.fp, err = os.Create(l.filename); err != nil { + return err + } + } else { + if l.fp, err = os.OpenFile(l.filename, os.O_APPEND|os.O_WRONLY, defaultFileMode); err != nil { + return err + } + l.currentSize = fileInfo.Size() + } + + nfs.CloseOnExec(l.fp) + + return nil +} + +func (l *RotateLogger) maybeCompressFile(file string) { + if !l.compress { + return + } + + defer func() { + if r := recover(); r != nil { + ErrorStack(r) + } + }() + + if _, err := os.Stat(file); err != nil { + // file not exists or other error, ignore compression + return + } + + compressLogFile(file) +} + +func (l *RotateLogger) maybeDeleteOutdatedFiles() { + files := l.rule.OutdatedFiles() + for _, file := range files { + if err := os.Remove(file); err != nil { + Errorf("failed to remove outdated file: %s", file) + } + } +} + +func (l *RotateLogger) postRotate(file string) { + go func() { + // we cannot use threading.GoSafe here, because of import cycle. + l.maybeCompressFile(file) + l.maybeDeleteOutdatedFiles() + }() +} + +func (l *RotateLogger) rotate() error { + if l.fp != nil { + err := l.fp.Close() + l.fp = nil + if err != nil { + return err + } + } + + _, err := os.Stat(l.filename) + if err == nil && len(l.backup) > 0 { + backupFilename := l.getBackupFilename() + err = os.Rename(l.filename, backupFilename) + if err != nil { + return err + } + + l.postRotate(backupFilename) + } + + l.backup = l.rule.BackupFileName() + if l.fp, err = os.Create(l.filename); err == nil { + nfs.CloseOnExec(l.fp) + } + + return err +} + +func (l *RotateLogger) startWorker() { + l.waitGroup.Add(1) + + go func() { + defer l.waitGroup.Done() + + for { + select { + case event := <-l.channel: + l.write(event) + case <-l.done: + return + } + } + }() +} + +func (l *RotateLogger) write(v []byte) { + if l.rule.ShallRotate(l.currentSize + int64(len(v))) { + if err := l.rotate(); err != nil { + log.Println(err) + } else { + l.rule.MarkRotated() + l.currentSize = 0 + } + } + if l.fp != nil { + l.fp.Write(v) + l.currentSize += int64(len(v)) + } +} + +func compressLogFile(file string) { + start := time.Now() + Infof("compressing log file: %s", file) + if err := gzipFile(file); err != nil { + Errorf("compress error: %s", err) + } else { + Infof("compressed log file: %s, took %s", file, time.Since(start)) + } +} + +func getNowDate() string { + return time.Now().Format(dateFormat) +} + +func getNowDateInRFC3339Format() string { + return time.Now().Format(fileTimeFormat) +} + +func gzipFile(file string) error { + in, err := os.Open(file) + if err != nil { + return err + } + defer in.Close() + + out, err := os.Create(fmt.Sprintf("%s%s", file, gzipExt)) + if err != nil { + return err + } + defer out.Close() + + w := gzip.NewWriter(out) + if _, err = io.Copy(w, in); err != nil { + return err + } else if err = w.Close(); err != nil { + return err + } + + return os.Remove(file) +} diff --git a/nlog/rotatelogger_test.go b/nlog/rotatelogger_test.go new file mode 100644 index 0000000..b0704e5 --- /dev/null +++ b/nlog/rotatelogger_test.go @@ -0,0 +1,348 @@ +package nlog + +import ( + "git.noahlan.cn/noahlan/ntool/nfs" + "git.noahlan.cn/noahlan/ntool/nrandom" + "os" + "path/filepath" + "syscall" + "testing" + "time" + + "git.noahlan.cn/noahlan/ntool/ntest/assert" +) + +func TestDailyRotateRuleMarkRotated(t *testing.T) { + var rule DailyRotateRule + rule.MarkRotated() + assert.Equal(t, getNowDate(), rule.rotatedTime) +} + +func TestDailyRotateRuleOutdatedFiles(t *testing.T) { + var rule DailyRotateRule + assert.Empty(t, rule.OutdatedFiles()) + rule.days = 1 + assert.Empty(t, rule.OutdatedFiles()) + rule.gzip = true + assert.Empty(t, rule.OutdatedFiles()) +} + +func TestDailyRotateRuleShallRotate(t *testing.T) { + var rule DailyRotateRule + rule.rotatedTime = time.Now().Add(time.Hour * 24).Format(dateFormat) + assert.True(t, rule.ShallRotate(0)) +} + +func TestSizeLimitRotateRuleMarkRotated(t *testing.T) { + var rule SizeLimitRotateRule + rule.MarkRotated() + assert.Equal(t, getNowDateInRFC3339Format(), rule.rotatedTime) +} + +func TestSizeLimitRotateRuleOutdatedFiles(t *testing.T) { + var rule SizeLimitRotateRule + assert.Empty(t, rule.OutdatedFiles()) + rule.days = 1 + assert.Empty(t, rule.OutdatedFiles()) + rule.gzip = true + assert.Empty(t, rule.OutdatedFiles()) + rule.maxBackups = 0 + assert.Empty(t, rule.OutdatedFiles()) +} + +func TestSizeLimitRotateRuleShallRotate(t *testing.T) { + var rule SizeLimitRotateRule + rule.rotatedTime = time.Now().Add(time.Hour * 24).Format(fileTimeFormat) + rule.maxSize = 0 + assert.False(t, rule.ShallRotate(0)) + rule.maxSize = 100 + assert.False(t, rule.ShallRotate(0)) + assert.True(t, rule.ShallRotate(101*megaBytes)) +} + +func TestRotateLoggerClose(t *testing.T) { + filename, err := nfs.OSTempFilenameWithContent("foo") + assert.Nil(t, err) + if len(filename) > 0 { + defer os.Remove(filename) + } + logger, err := NewLogger(filename, new(DailyRotateRule), false) + assert.Nil(t, err) + assert.Nil(t, logger.Close()) +} + +func TestRotateLoggerGetBackupFilename(t *testing.T) { + filename, err := nfs.OSTempFilenameWithContent("foo") + assert.Nil(t, err) + if len(filename) > 0 { + defer os.Remove(filename) + } + logger, err := NewLogger(filename, new(DailyRotateRule), false) + assert.Nil(t, err) + assert.True(t, len(logger.getBackupFilename()) > 0) + logger.backup = "" + assert.True(t, len(logger.getBackupFilename()) > 0) +} + +func TestRotateLoggerMayCompressFile(t *testing.T) { + old := os.Stdout + os.Stdout = os.NewFile(0, os.DevNull) + defer func() { + os.Stdout = old + }() + + filename, err := nfs.OSTempFilenameWithContent("foo") + assert.Nil(t, err) + if len(filename) > 0 { + defer os.Remove(filename) + } + logger, err := NewLogger(filename, new(DailyRotateRule), false) + assert.Nil(t, err) + logger.maybeCompressFile(filename) + _, err = os.Stat(filename) + assert.Nil(t, err) +} + +func TestRotateLoggerMayCompressFileTrue(t *testing.T) { + old := os.Stdout + os.Stdout = os.NewFile(0, os.DevNull) + defer func() { + os.Stdout = old + }() + + filename, err := nfs.OSTempFilenameWithContent("foo") + assert.Nil(t, err) + logger, err := NewLogger(filename, new(DailyRotateRule), true) + assert.Nil(t, err) + if len(filename) > 0 { + defer os.Remove(filepath.Base(logger.getBackupFilename()) + ".gz") + } + logger.maybeCompressFile(filename) + _, err = os.Stat(filename) + assert.NotNil(t, err) +} + +func TestRotateLoggerRotate(t *testing.T) { + filename, err := nfs.OSTempFilenameWithContent("foo") + assert.Nil(t, err) + logger, err := NewLogger(filename, new(DailyRotateRule), true) + assert.Nil(t, err) + if len(filename) > 0 { + defer func() { + os.Remove(logger.getBackupFilename()) + os.Remove(filepath.Base(logger.getBackupFilename()) + ".gz") + }() + } + err = logger.rotate() + switch v := err.(type) { + case *os.LinkError: + // avoid rename error on docker container + assert.Equal(t, syscall.EXDEV, v.Err) + case *os.PathError: + // ignore remove error for tests, + // files are cleaned in GitHub actions. + assert.Equal(t, "remove", v.Op) + default: + assert.Nil(t, err) + } +} + +func TestRotateLoggerWrite(t *testing.T) { + filename, err := nfs.OSTempFilenameWithContent("foo") + assert.Nil(t, err) + rule := new(DailyRotateRule) + logger, err := NewLogger(filename, rule, true) + assert.Nil(t, err) + if len(filename) > 0 { + defer func() { + os.Remove(logger.getBackupFilename()) + os.Remove(filepath.Base(logger.getBackupFilename()) + ".gz") + }() + } + // the following write calls cannot be changed to Write, because of DATA RACE. + logger.write([]byte(`foo`)) + rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(dateFormat) + logger.write([]byte(`bar`)) + logger.Close() + logger.write([]byte(`baz`)) +} + +func TestLogWriterClose(t *testing.T) { + assert.Nil(t, newLogWriter(nil).Close()) +} + +func TestRotateLoggerWithSizeLimitRotateRuleClose(t *testing.T) { + filename, err := nfs.OSTempFilenameWithContent("foo") + assert.Nil(t, err) + if len(filename) > 0 { + defer os.Remove(filename) + } + logger, err := NewLogger(filename, new(SizeLimitRotateRule), false) + assert.Nil(t, err) + assert.Nil(t, logger.Close()) +} + +func TestRotateLoggerGetBackupWithSizeLimitRotateRuleFilename(t *testing.T) { + filename, err := nfs.OSTempFilenameWithContent("foo") + assert.Nil(t, err) + if len(filename) > 0 { + defer os.Remove(filename) + } + logger, err := NewLogger(filename, new(SizeLimitRotateRule), false) + assert.Nil(t, err) + assert.True(t, len(logger.getBackupFilename()) > 0) + logger.backup = "" + assert.True(t, len(logger.getBackupFilename()) > 0) +} + +func TestRotateLoggerWithSizeLimitRotateRuleMayCompressFile(t *testing.T) { + old := os.Stdout + os.Stdout = os.NewFile(0, os.DevNull) + defer func() { + os.Stdout = old + }() + + filename, err := nfs.OSTempFilenameWithContent("foo") + assert.Nil(t, err) + if len(filename) > 0 { + defer os.Remove(filename) + } + logger, err := NewLogger(filename, new(SizeLimitRotateRule), false) + assert.Nil(t, err) + logger.maybeCompressFile(filename) + _, err = os.Stat(filename) + assert.Nil(t, err) +} + +func TestRotateLoggerWithSizeLimitRotateRuleMayCompressFileTrue(t *testing.T) { + old := os.Stdout + os.Stdout = os.NewFile(0, os.DevNull) + defer func() { + os.Stdout = old + }() + + filename, err := nfs.OSTempFilenameWithContent("foo") + assert.Nil(t, err) + logger, err := NewLogger(filename, new(SizeLimitRotateRule), true) + assert.Nil(t, err) + if len(filename) > 0 { + defer os.Remove(filepath.Base(logger.getBackupFilename()) + ".gz") + } + logger.maybeCompressFile(filename) + _, err = os.Stat(filename) + assert.NotNil(t, err) +} + +func TestRotateLoggerWithSizeLimitRotateRuleMayCompressFileFailed(t *testing.T) { + old := os.Stdout + os.Stdout = os.NewFile(0, os.DevNull) + defer func() { + os.Stdout = old + }() + + filename := nrandom.NewUUIDV7().String() + logger, err := NewLogger(filename, new(SizeLimitRotateRule), true) + defer os.Remove(filename) + if assert.NoError(t, err) { + assert.NotPanics(t, func() { + logger.maybeCompressFile(nrandom.NewUUIDV7().String()) + }) + } +} + +func TestRotateLoggerWithSizeLimitRotateRuleRotate(t *testing.T) { + filename, err := nfs.OSTempFilenameWithContent("foo") + assert.Nil(t, err) + logger, err := NewLogger(filename, new(SizeLimitRotateRule), true) + assert.Nil(t, err) + if len(filename) > 0 { + defer func() { + os.Remove(logger.getBackupFilename()) + os.Remove(filepath.Base(logger.getBackupFilename()) + ".gz") + }() + } + err = logger.rotate() + switch v := err.(type) { + case *os.LinkError: + // avoid rename error on docker container + assert.Equal(t, syscall.EXDEV, v.Err) + case *os.PathError: + // ignore remove error for tests, + // files are cleaned in GitHub actions. + assert.Equal(t, "remove", v.Op) + default: + assert.Nil(t, err) + } +} + +func TestRotateLoggerWithSizeLimitRotateRuleWrite(t *testing.T) { + filename, err := nfs.OSTempFilenameWithContent("foo") + assert.Nil(t, err) + rule := new(SizeLimitRotateRule) + logger, err := NewLogger(filename, rule, true) + assert.Nil(t, err) + if len(filename) > 0 { + defer func() { + os.Remove(logger.getBackupFilename()) + os.Remove(filepath.Base(logger.getBackupFilename()) + ".gz") + }() + } + // the following write calls cannot be changed to Write, because of DATA RACE. + logger.write([]byte(`foo`)) + rule.rotatedTime = time.Now().Add(-time.Hour * 24).Format(dateFormat) + logger.write([]byte(`bar`)) + logger.Close() + logger.write([]byte(`baz`)) +} + +func BenchmarkRotateLogger(b *testing.B) { + filename := "./test.log" + filename2 := "./test2.log" + dailyRotateRuleLogger, err1 := NewLogger( + filename, + DefaultRotateRule( + filename, + backupFileDelimiter, + 1, + true, + ), + true, + ) + if err1 != nil { + b.Logf("Failed to new daily rotate rule logger: %v", err1) + b.FailNow() + } + sizeLimitRotateRuleLogger, err2 := NewLogger( + filename2, + NewSizeLimitRotateRule( + filename, + backupFileDelimiter, + 1, + 100, + 10, + true, + ), + true, + ) + if err2 != nil { + b.Logf("Failed to new size limit rotate rule logger: %v", err1) + b.FailNow() + } + defer func() { + dailyRotateRuleLogger.Close() + sizeLimitRotateRuleLogger.Close() + os.Remove(filename) + os.Remove(filename2) + }() + + b.Run("daily rotate rule", func(b *testing.B) { + for i := 0; i < b.N; i++ { + dailyRotateRuleLogger.write([]byte("testing\ntesting\n")) + } + }) + b.Run("size limit rotate rule", func(b *testing.B) { + for i := 0; i < b.N; i++ { + sizeLimitRotateRuleLogger.write([]byte("testing\ntesting\n")) + } + }) +} diff --git a/nlog/syslog.go b/nlog/syslog.go new file mode 100644 index 0000000..ad63d2c --- /dev/null +++ b/nlog/syslog.go @@ -0,0 +1,15 @@ +package nlog + +import "log" + +type redirector struct{} + +// CollectSysLog redirects system log into logx info +func CollectSysLog() { + log.SetOutput(new(redirector)) +} + +func (r *redirector) Write(p []byte) (n int, err error) { + Info(string(p)) + return len(p), nil +} diff --git a/nlog/syslog_test.go b/nlog/syslog_test.go new file mode 100644 index 0000000..7bed8f4 --- /dev/null +++ b/nlog/syslog_test.go @@ -0,0 +1,59 @@ +package nlog + +import ( + "encoding/json" + "log" + "strings" + "sync/atomic" + "testing" + + "git.noahlan.cn/noahlan/ntool/ntest/assert" +) + +const testlog = "Stay hungry, stay foolish." + +func TestCollectSysLog(t *testing.T) { + CollectSysLog() + content := getContent(captureOutput(func() { + log.Print(testlog) + })) + assert.True(t, strings.Contains(content, testlog)) +} + +func TestRedirector(t *testing.T) { + var r redirector + content := getContent(captureOutput(func() { + r.Write([]byte(testlog)) + })) + assert.Equal(t, testlog, content) +} + +func captureOutput(f func()) string { + w := new(mockWriter) + old := writer.Swap(w) + defer writer.Store(old) + + prevLevel := atomic.LoadUint32(&logLevel) + SetLevel(InfoLevel) + f() + SetLevel(prevLevel) + + return w.String() +} + +func getContent(jsonStr string) string { + var entry map[string]any + json.Unmarshal([]byte(jsonStr), &entry) + + val, ok := entry[contentKey] + if !ok { + return "" + } + + str, ok := val.(string) + if !ok { + return "" + } + + return str +} diff --git a/nlog/util.go b/nlog/util.go new file mode 100644 index 0000000..8589938 --- /dev/null +++ b/nlog/util.go @@ -0,0 +1,55 @@ +package nlog + +import ( + "context" + "fmt" + "go.opentelemetry.io/otel/trace" + "runtime" + "strings" + "time" +) + +func spanIDFromContext(ctx context.Context) string { + spanCtx := trace.SpanContextFromContext(ctx) + if spanCtx.HasSpanID() { + return spanCtx.SpanID().String() + } + + return "" +} + +func traceIDFromContext(ctx context.Context) string { + spanCtx := trace.SpanContextFromContext(ctx) + if spanCtx.HasTraceID() { + return spanCtx.TraceID().String() + } + + return "" +} + +func getCaller(callDepth int) string { + _, file, line, ok := runtime.Caller(callDepth) + if !ok { + return "" + } + + return prettyCaller(file, line) +} + +func getTimestamp() string { + return time.Now().Format(timeFormat) +} + +func prettyCaller(file string, line int) string { + idx := strings.LastIndexByte(file, '/') + if idx < 0 { + return fmt.Sprintf("%s:%d", file, line) + } + + idx = strings.LastIndexByte(file[:idx], '/') + if idx < 0 { + return fmt.Sprintf("%s:%d", file, line) + } + + return fmt.Sprintf("%s:%d", file[idx+1:], line) +} diff --git a/nlog/util_test.go b/nlog/util_test.go new file mode 100644 index 0000000..fb328f3 --- /dev/null +++ b/nlog/util_test.go @@ -0,0 +1,72 @@ +package nlog + +import ( + "path/filepath" + "runtime" + "testing" + "time" + + "git.noahlan.cn/noahlan/ntool/ntest/assert" +) + +func TestGetCaller(t *testing.T) { + _, file, _, _ := runtime.Caller(0) + assert.Contains(t, getCaller(1), filepath.Base(file)) + assert.True(t, len(getCaller(1<<10)) == 0) +} + +func TestGetTimestamp(t *testing.T) { + ts := getTimestamp() + tm, err := time.Parse(timeFormat, ts) + assert.Nil(t, err) + assert.True(t, time.Since(tm) < time.Minute) +} + +func TestPrettyCaller(t *testing.T) { + tests := []struct { + name string + file string + line int + want string + }{ + { + name: "regular", + file: "logx_test.go", + line: 123, + want: "logx_test.go:123", + }, + { + name: "relative", + file: "adhoc/logx_test.go", + line: 123, + want: "adhoc/logx_test.go:123", + }, + { + name: "long path", + file: "github.com/zeromicro/go-zero/core/logx/util_test.go", + line: 12, + want: "logx/util_test.go:12", + }, + { + name: "local path", + file: "/Users/kevin/go-zero/core/logx/util_test.go", + line: 1234, + want: "logx/util_test.go:1234", + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + assert.Equal(t, test.want, prettyCaller(test.file, test.line)) + }) + } +} + +func BenchmarkGetCaller(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + getCaller(1) + } +} diff --git a/nlog/vars.go b/nlog/vars.go new file mode 100644 index 0000000..50bc536 --- /dev/null +++ b/nlog/vars.go @@ -0,0 +1,66 @@ +package nlog + +import "errors" + +const ( + // DebugLevel logs everything + DebugLevel uint32 = iota + // InfoLevel does not include debugs + InfoLevel + // ErrorLevel includes errors, slows, stacks + ErrorLevel + // SevereLevel only log severe messages + SevereLevel +) + +const ( + jsonEncodingType = iota + plainEncodingType +) + +const ( + plainEncoding = "plain" + plainEncodingSep = '\t' + sizeRotationRule = "size" + + accessFilename = "access.log" + errorFilename = "error.log" + severeFilename = "severe.log" + slowFilename = "slow.log" + statFilename = "stat.log" + + fileMode = "file" + volumeMode = "volume" + + levelAlert = "alert" + levelInfo = "info" + levelError = "error" + levelSevere = "severe" + levelFatal = "fatal" + levelSlow = "slow" + levelStat = "stat" + levelDebug = "debug" + + backupFileDelimiter = "-" + flags = 0x0 +) + +const ( + callerKey = "caller" + contentKey = "content" + durationKey = "duration" + levelKey = "level" + spanKey = "span" + timestampKey = "@timestamp" + traceKey = "trace" + truncatedKey = "truncated" +) + +var ( + // ErrLogPathNotSet is an error that indicates the log path is not set. + ErrLogPathNotSet = errors.New("log path must be set") + // ErrLogServiceNameNotSet is an error that indicates that the service name is not set. + ErrLogServiceNameNotSet = errors.New("log service name must be set") + + truncatedField = Field(truncatedKey, true) +) diff --git a/nlog/writer.go b/nlog/writer.go new file mode 100644 index 0000000..ff56b4d --- /dev/null +++ b/nlog/writer.go @@ -0,0 +1,403 @@ +package nlog + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/gookit/color" + "github.com/mattn/go-colorable" + "io" + "log" + "path" + "sync" + "sync/atomic" +) + +type ( + Writer interface { + Alert(v any) + Close() error + Debug(v any, fields ...LogField) + Error(v any, fields ...LogField) + Info(v any, fields ...LogField) + Severe(v any) + Slow(v any, fields ...LogField) + Stack(v any) + Stat(v any, fields ...LogField) + } + + atomicWriter struct { + writer Writer + lock sync.RWMutex + } + + concreteWriter struct { + infoLog io.WriteCloser + errorLog io.WriteCloser + severeLog io.WriteCloser + slowLog io.WriteCloser + statLog io.WriteCloser + stackLog io.Writer + } +) + +// NewWriter creates a new Writer with the given io.Writer. +func NewWriter(w io.Writer) Writer { + lw := newLogWriter(log.New(w, "", flags)) + + return &concreteWriter{ + infoLog: lw, + errorLog: lw, + severeLog: lw, + slowLog: lw, + statLog: lw, + stackLog: lw, + } +} + +func (w *atomicWriter) Load() Writer { + w.lock.RLock() + defer w.lock.RUnlock() + return w.writer +} + +func (w *atomicWriter) Store(v Writer) { + w.lock.Lock() + defer w.lock.Unlock() + w.writer = v +} + +func (w *atomicWriter) StoreIfNil(v Writer) Writer { + w.lock.Lock() + defer w.lock.Unlock() + + if w.writer == nil { + w.writer = v + } + + return w.writer +} + +func (w *atomicWriter) Swap(v Writer) Writer { + w.lock.Lock() + defer w.lock.Unlock() + old := w.writer + w.writer = v + return old +} + +func newConsoleWriter() Writer { + outLog := newLogWriter(log.New(colorable.NewColorableStdout(), "", flags)) + errLog := newLogWriter(log.New(colorable.NewColorableStderr(), "", flags)) + return &concreteWriter{ + infoLog: outLog, + errorLog: errLog, + severeLog: errLog, + slowLog: errLog, + stackLog: newLessWriter(errLog, options.logStackCooldownMills), + statLog: outLog, + } +} + +func newFileWriter(c LogConf) (Writer, error) { + var err error + var opts []LogOption + var infoLog io.WriteCloser + var errorLog io.WriteCloser + var severeLog io.WriteCloser + var slowLog io.WriteCloser + var statLog io.WriteCloser + var stackLog io.Writer + + if len(c.Path) == 0 { + return nil, ErrLogPathNotSet + } + + opts = append(opts, WithCooldownMillis(c.StackCooldownMillis)) + if c.Compress { + opts = append(opts, WithGzip()) + } + if c.KeepDays > 0 { + opts = append(opts, WithKeepDays(c.KeepDays)) + } + if c.MaxBackups > 0 { + opts = append(opts, WithMaxBackups(c.MaxBackups)) + } + if c.MaxSize > 0 { + opts = append(opts, WithMaxSize(c.MaxSize)) + } + + opts = append(opts, WithRotation(c.Rotation)) + + accessFile := path.Join(c.Path, accessFilename) + errorFile := path.Join(c.Path, errorFilename) + severeFile := path.Join(c.Path, severeFilename) + slowFile := path.Join(c.Path, slowFilename) + statFile := path.Join(c.Path, statFilename) + + handleOptions(opts) + setupLogLevel(c) + + if infoLog, err = createOutput(accessFile); err != nil { + return nil, err + } + + if errorLog, err = createOutput(errorFile); err != nil { + return nil, err + } + + if severeLog, err = createOutput(severeFile); err != nil { + return nil, err + } + + if slowLog, err = createOutput(slowFile); err != nil { + return nil, err + } + + if statLog, err = createOutput(statFile); err != nil { + return nil, err + } + + stackLog = newLessWriter(errorLog, options.logStackCooldownMills) + + return &concreteWriter{ + infoLog: infoLog, + errorLog: errorLog, + severeLog: severeLog, + slowLog: slowLog, + statLog: statLog, + stackLog: stackLog, + }, nil +} + +func (w *concreteWriter) Alert(v any) { + output(w.errorLog, levelAlert, v) +} + +func (w *concreteWriter) Close() error { + if err := w.infoLog.Close(); err != nil { + return err + } + + if err := w.errorLog.Close(); err != nil { + return err + } + + if err := w.severeLog.Close(); err != nil { + return err + } + + if err := w.slowLog.Close(); err != nil { + return err + } + + return w.statLog.Close() +} + +func (w *concreteWriter) Debug(v any, fields ...LogField) { + output(w.infoLog, levelDebug, v, fields...) +} + +func (w *concreteWriter) Error(v any, fields ...LogField) { + output(w.errorLog, levelError, v, fields...) +} + +func (w *concreteWriter) Info(v any, fields ...LogField) { + output(w.infoLog, levelInfo, v, fields...) +} + +func (w *concreteWriter) Severe(v any) { + output(w.severeLog, levelFatal, v) +} + +func (w *concreteWriter) Slow(v any, fields ...LogField) { + output(w.slowLog, levelSlow, v, fields...) +} + +func (w *concreteWriter) Stack(v any) { + output(w.stackLog, levelError, v) +} + +func (w *concreteWriter) Stat(v any, fields ...LogField) { + output(w.statLog, levelStat, v, fields...) +} + +type nopWriter struct{} + +func (n nopWriter) Alert(_ any) { +} + +func (n nopWriter) Close() error { + return nil +} + +func (n nopWriter) Debug(_ any, _ ...LogField) { +} + +func (n nopWriter) Error(_ any, _ ...LogField) { +} + +func (n nopWriter) Info(_ any, _ ...LogField) { +} + +func (n nopWriter) Severe(_ any) { +} + +func (n nopWriter) Slow(_ any, _ ...LogField) { +} + +func (n nopWriter) Stack(_ any) { +} + +func (n nopWriter) Stat(_ any, _ ...LogField) { +} + +func buildPlainFields(fields ...LogField) []string { + var items []string + + for _, field := range fields { + items = append(items, fmt.Sprintf("%s=%+v", field.Key, field.Value)) + } + + return items +} + +func combineGlobalFields(fields []LogField) []LogField { + globals := globalFields.Load() + if globals == nil { + return fields + } + + gf := globals.([]LogField) + ret := make([]LogField, 0, len(gf)+len(fields)) + ret = append(ret, gf...) + ret = append(ret, fields...) + + return ret +} + +func output(writer io.Writer, level string, val any, fields ...LogField) { + // only truncate string content, don't know how to truncate the values of other types. + if v, ok := val.(string); ok { + maxLen := atomic.LoadUint32(&maxContentLength) + if maxLen > 0 && len(v) > int(maxLen) { + val = v[:maxLen] + fields = append(fields, truncatedField) + } + } + + fields = combineGlobalFields(fields) + + switch atomic.LoadUint32(&encoding) { + case plainEncodingType: + writePlainAny(writer, level, val, buildPlainFields(fields...)...) + default: + entry := make(logEntry) + for _, field := range fields { + entry[field.Key] = field.Value + } + entry[timestampKey] = getTimestamp() + entry[levelKey] = level + entry[contentKey] = val + writeJson(writer, entry) + } +} + +func wrapLevelWithColor(level string) string { + var colour color.Color + switch level { + case levelAlert: + colour = color.FgRed + case levelError: + colour = color.FgRed + case levelFatal: + colour = color.FgRed + case levelInfo: + colour = color.FgBlue + case levelSlow: + colour = color.FgYellow + case levelDebug: + colour = color.FgYellow + case levelStat: + colour = color.FgGreen + } + + if colour == color.Normal { + return level + } + + return WithColorPadding(level, colour) +} + +func writeJson(writer io.Writer, info any) { + if content, err := json.Marshal(info); err != nil { + log.Println(err.Error()) + } else if writer == nil { + log.Println(string(content)) + } else { + writer.Write(append(content, '\n')) + } +} + +func writePlainAny(writer io.Writer, level string, val any, fields ...string) { + level = wrapLevelWithColor(level) + + switch v := val.(type) { + case string: + writePlainText(writer, level, v, fields...) + case error: + writePlainText(writer, level, v.Error(), fields...) + case fmt.Stringer: + writePlainText(writer, level, v.String(), fields...) + default: + writePlainValue(writer, level, v, fields...) + } +} + +func writePlainText(writer io.Writer, level, msg string, fields ...string) { + var buf bytes.Buffer + buf.WriteString(getTimestamp()) + buf.WriteByte(plainEncodingSep) + buf.WriteString(level) + buf.WriteByte(plainEncodingSep) + buf.WriteString(msg) + for _, item := range fields { + buf.WriteByte(plainEncodingSep) + buf.WriteString(item) + } + buf.WriteByte('\n') + if writer == nil { + log.Println(buf.String()) + return + } + + if _, err := writer.Write(buf.Bytes()); err != nil { + log.Println(err.Error()) + } +} + +func writePlainValue(writer io.Writer, level string, val any, fields ...string) { + var buf bytes.Buffer + buf.WriteString(getTimestamp()) + buf.WriteByte(plainEncodingSep) + buf.WriteString(level) + buf.WriteByte(plainEncodingSep) + if err := json.NewEncoder(&buf).Encode(val); err != nil { + log.Println(err.Error()) + return + } + + for _, item := range fields { + buf.WriteByte(plainEncodingSep) + buf.WriteString(item) + } + buf.WriteByte('\n') + if writer == nil { + log.Println(buf.String()) + return + } + + if _, err := writer.Write(buf.Bytes()); err != nil { + log.Println(err.Error()) + } +} diff --git a/nlog/writer_test.go b/nlog/writer_test.go new file mode 100644 index 0000000..51b1b4c --- /dev/null +++ b/nlog/writer_test.go @@ -0,0 +1,221 @@ +package nlog + +import ( + "bytes" + "encoding/json" + "errors" + "log" + "sync/atomic" + "testing" + + "git.noahlan.cn/noahlan/ntool/ntest/assert" +) + +func TestNewWriter(t *testing.T) { + const literal = "foo bar" + var buf bytes.Buffer + w := NewWriter(&buf) + w.Info(literal) + assert.Contains(t, buf.String(), literal) + buf.Reset() + w.Debug(literal) + assert.Contains(t, buf.String(), literal) +} + +func TestConsoleWriter(t *testing.T) { + var buf bytes.Buffer + w := newConsoleWriter() + lw := newLogWriter(log.New(&buf, "", 0)) + w.(*concreteWriter).errorLog = lw + w.Alert("foo bar 1") + var val mockedEntry + if err := json.Unmarshal(buf.Bytes(), &val); err != nil { + t.Fatal(err) + } + assert.Equal(t, levelAlert, val.Level) + assert.Equal(t, "foo bar 1", val.Content) + + buf.Reset() + w.(*concreteWriter).errorLog = lw + w.Error("foo bar 2") + if err := json.Unmarshal(buf.Bytes(), &val); err != nil { + t.Fatal(err) + } + assert.Equal(t, levelError, val.Level) + assert.Equal(t, "foo bar 2", val.Content) + + buf.Reset() + w.(*concreteWriter).infoLog = lw + w.Info("foo bar 3") + if err := json.Unmarshal(buf.Bytes(), &val); err != nil { + t.Fatal(err) + } + assert.Equal(t, levelInfo, val.Level) + assert.Equal(t, "foo bar 3", val.Content) + + buf.Reset() + w.(*concreteWriter).severeLog = lw + w.Severe("foo bar 4") + if err := json.Unmarshal(buf.Bytes(), &val); err != nil { + t.Fatal(err) + } + assert.Equal(t, levelFatal, val.Level) + assert.Equal(t, "foo bar 4", val.Content) + + buf.Reset() + w.(*concreteWriter).slowLog = lw + w.Slow("foo bar 5") + if err := json.Unmarshal(buf.Bytes(), &val); err != nil { + t.Fatal(err) + } + assert.Equal(t, levelSlow, val.Level) + assert.Equal(t, "foo bar 5", val.Content) + + buf.Reset() + w.(*concreteWriter).statLog = lw + w.Stat("foo bar 6") + if err := json.Unmarshal(buf.Bytes(), &val); err != nil { + t.Fatal(err) + } + assert.Equal(t, levelStat, val.Level) + assert.Equal(t, "foo bar 6", val.Content) + + w.(*concreteWriter).infoLog = hardToCloseWriter{} + assert.NotNil(t, w.Close()) + w.(*concreteWriter).infoLog = easyToCloseWriter{} + w.(*concreteWriter).errorLog = hardToCloseWriter{} + assert.NotNil(t, w.Close()) + w.(*concreteWriter).errorLog = easyToCloseWriter{} + w.(*concreteWriter).severeLog = hardToCloseWriter{} + assert.NotNil(t, w.Close()) + w.(*concreteWriter).severeLog = easyToCloseWriter{} + w.(*concreteWriter).slowLog = hardToCloseWriter{} + assert.NotNil(t, w.Close()) + w.(*concreteWriter).slowLog = easyToCloseWriter{} + w.(*concreteWriter).statLog = hardToCloseWriter{} + assert.NotNil(t, w.Close()) + w.(*concreteWriter).statLog = easyToCloseWriter{} +} + +func TestNopWriter(t *testing.T) { + assert.NotPanics(t, func() { + var w nopWriter + w.Alert("foo") + w.Debug("foo") + w.Error("foo") + w.Info("foo") + w.Severe("foo") + w.Stack("foo") + w.Stat("foo") + w.Slow("foo") + _ = w.Close() + }) +} + +func TestWriteJson(t *testing.T) { + var buf bytes.Buffer + log.SetOutput(&buf) + writeJson(nil, "foo") + assert.Contains(t, buf.String(), "foo") + buf.Reset() + writeJson(nil, make(chan int)) + assert.Contains(t, buf.String(), "unsupported type") +} + +func TestWritePlainAny(t *testing.T) { + var buf bytes.Buffer + log.SetOutput(&buf) + writePlainAny(nil, levelInfo, "foo") + assert.Contains(t, buf.String(), "foo") + + buf.Reset() + writePlainAny(nil, levelDebug, make(chan int)) + assert.Contains(t, buf.String(), "unsupported type") + writePlainAny(nil, levelDebug, 100) + assert.Contains(t, buf.String(), "100") + + buf.Reset() + writePlainAny(nil, levelError, make(chan int)) + assert.Contains(t, buf.String(), "unsupported type") + writePlainAny(nil, levelSlow, 100) + assert.Contains(t, buf.String(), "100") + + buf.Reset() + writePlainAny(hardToWriteWriter{}, levelStat, 100) + assert.Contains(t, buf.String(), "write error") + + buf.Reset() + writePlainAny(hardToWriteWriter{}, levelSevere, "foo") + assert.Contains(t, buf.String(), "write error") + + buf.Reset() + writePlainAny(hardToWriteWriter{}, levelAlert, "foo") + assert.Contains(t, buf.String(), "write error") + + buf.Reset() + writePlainAny(hardToWriteWriter{}, levelFatal, "foo") + assert.Contains(t, buf.String(), "write error") + +} + +func TestLogWithLimitContentLength(t *testing.T) { + maxLen := atomic.LoadUint32(&maxContentLength) + atomic.StoreUint32(&maxContentLength, 10) + + t.Cleanup(func() { + atomic.StoreUint32(&maxContentLength, maxLen) + }) + + t.Run("alert", func(t *testing.T) { + var buf bytes.Buffer + w := NewWriter(&buf) + w.Info("1234567890") + var v1 mockedEntry + if err := json.Unmarshal(buf.Bytes(), &v1); err != nil { + t.Fatal(err) + } + assert.Equal(t, "1234567890", v1.Content) + assert.False(t, v1.Truncated) + + buf.Reset() + var v2 mockedEntry + w.Info("12345678901") + if err := json.Unmarshal(buf.Bytes(), &v2); err != nil { + t.Fatal(err) + } + assert.Equal(t, "1234567890", v2.Content) + assert.True(t, v2.Truncated) + }) +} + +type mockedEntry struct { + Level string `json:"level"` + Content string `json:"content"` + Truncated bool `json:"truncated"` +} + +type easyToCloseWriter struct{} + +func (h easyToCloseWriter) Write(_ []byte) (_ int, _ error) { + return +} + +func (h easyToCloseWriter) Close() error { + return nil +} + +type hardToCloseWriter struct{} + +func (h hardToCloseWriter) Write(_ []byte) (_ int, _ error) { + return +} + +func (h hardToCloseWriter) Close() error { + return errors.New("close error") +} + +type hardToWriteWriter struct{} + +func (h hardToWriteWriter) Write(_ []byte) (_ int, _ error) { + return 0, errors.New("write error") +} diff --git a/nmap/check.go b/nmap/check.go new file mode 100644 index 0000000..925dc97 --- /dev/null +++ b/nmap/check.go @@ -0,0 +1,63 @@ +package nmap + +import ( + "git.noahlan.cn/noahlan/ntool/nreflect" + "reflect" +) + +// HasKey check of the given map. +func HasKey(mp, key any) (ok bool) { + rftVal := reflect.Indirect(reflect.ValueOf(mp)) + if rftVal.Kind() != reflect.Map { + return + } + + for _, keyRv := range rftVal.MapKeys() { + if nreflect.IsEqual(keyRv.Interface(), key) { + return true + } + } + return +} + +// HasOneKey check of the given map. return the first exist key +func HasOneKey(mp any, keys ...any) (ok bool, key any) { + rftVal := reflect.Indirect(reflect.ValueOf(mp)) + if rftVal.Kind() != reflect.Map { + return + } + + for _, key = range keys { + for _, keyRv := range rftVal.MapKeys() { + if nreflect.IsEqual(keyRv.Interface(), key) { + return true, key + } + } + } + + return false, nil +} + +// HasAllKeys check of the given map. return the first not exist key +func HasAllKeys(mp any, keys ...any) (ok bool, noKey any) { + rftVal := reflect.Indirect(reflect.ValueOf(mp)) + if rftVal.Kind() != reflect.Map { + return + } + + for _, key := range keys { + var exist bool + for _, keyRv := range rftVal.MapKeys() { + if nreflect.IsEqual(keyRv.Interface(), key) { + exist = true + break + } + } + + if !exist { + return false, key + } + } + + return true, nil +} diff --git a/nmap/check_test.go b/nmap/check_test.go new file mode 100644 index 0000000..ff5f3eb --- /dev/null +++ b/nmap/check_test.go @@ -0,0 +1,33 @@ +package nmap_test + +import ( + "git.noahlan.cn/noahlan/ntool/nmap" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestHasKey(t *testing.T) { + var mp any = map[string]string{"key0": "val0"} + + assert.True(t, nmap.HasKey(mp, "key0")) + assert.False(t, nmap.HasKey(mp, "not-exist")) + assert.False(t, nmap.HasKey("abc", "not-exist")) +} + +func TestHasAllKeys(t *testing.T) { + var mp any = map[string]string{"key0": "val0", "key1": "def"} + ok, noKey := nmap.HasAllKeys(mp, "key0") + assert.True(t, ok) + assert.Nil(t, noKey) + + ok, noKey = nmap.HasAllKeys(mp, "key0", "key1") + assert.True(t, ok) + assert.Nil(t, noKey) + + ok, noKey = nmap.HasAllKeys(mp, "key0", "not-exist") + assert.False(t, ok) + assert.Eq(t, "not-exist", noKey) + + ok, _ = nmap.HasAllKeys(mp, "invalid-map", "not-exist") + assert.False(t, ok) +} diff --git a/nmap/convert.go b/nmap/convert.go new file mode 100644 index 0000000..159f204 --- /dev/null +++ b/nmap/convert.go @@ -0,0 +1,142 @@ +package nmap + +import ( + "errors" + "git.noahlan.cn/noahlan/ntool/narr" + "git.noahlan.cn/noahlan/ntool/ndef" + "git.noahlan.cn/noahlan/ntool/nreflect" + "git.noahlan.cn/noahlan/ntool/nstr" + "reflect" + "strings" +) + +// KeyToLower convert keys to lower case. +func KeyToLower(src map[string]string) map[string]string { + newMp := make(map[string]string, len(src)) + for k, v := range src { + k = strings.ToLower(k) + newMp[k] = v + } + return newMp +} + +// ToStringMap convert map[string]any to map[string]string +func ToStringMap(src map[string]any) map[string]string { + strMp := make(map[string]string, len(src)) + for k, v := range src { + strMp[k] = nstr.SafeString(v) + } + return strMp +} + +// CombineToSMap combine two string-slice to SMap(map[string]string) +func CombineToSMap(keys, values []string) SMap { + return narr.CombineToSMap(keys, values) +} + +// CombineToMap combine two any slice to map[K]V. alias of arrutil.CombineToMap +func CombineToMap[K ndef.SortedType, V any](keys []K, values []V) map[K]V { + return narr.CombineToMap(keys, values) +} + +// ToAnyMap convert map[TYPE1]TYPE2 to map[string]any +func ToAnyMap(mp any) map[string]any { + amp, _ := TryAnyMap(mp) + return amp +} + +// TryAnyMap convert map[TYPE1]TYPE2 to map[string]any +func TryAnyMap(mp any) (map[string]any, error) { + if aMp, ok := mp.(map[string]any); ok { + return aMp, nil + } + + rv := reflect.Indirect(reflect.ValueOf(mp)) + if rv.Kind() != reflect.Map { + return nil, errors.New("input is not a map value") + } + + anyMp := make(map[string]any, rv.Len()) + for _, key := range rv.MapKeys() { + anyMp[key.String()] = rv.MapIndex(key).Interface() + } + return anyMp, nil +} + +// HTTPQueryString convert map[string]any data to http query string. +func HTTPQueryString(data map[string]any) string { + ss := make([]string, 0, len(data)) + for k, v := range data { + ss = append(ss, k+"="+nstr.SafeString(v)) + } + + return strings.Join(ss, "&") +} + +// ToString simple and quickly convert map[string]any to string. +func ToString(mp map[string]any) string { + if mp == nil { + return "" + } + if len(mp) == 0 { + return "{}" + } + + buf := make([]byte, 0, len(mp)*16) + buf = append(buf, '{') + + for k, val := range mp { + buf = append(buf, k...) + buf = append(buf, ':') + + str := nstr.SafeString(val) + buf = append(buf, str...) + buf = append(buf, ',', ' ') + } + + // remove last ', ' + buf = append(buf[:len(buf)-2], '}') + return nstr.Byte2str(buf) +} + +// ToString2 simple and quickly convert a map to string. +func ToString2(mp any) string { + return NewFormatter(mp).Format() +} + +// FormatIndent format map data to string with newline and indent. +func FormatIndent(mp any, indent string) string { + return NewFormatter(mp).WithIndent(indent).Format() +} + +/************************************************************* + * Flat convert tree map to flatten key-value map. + *************************************************************/ + +// Flatten convert tree map to flat key-value map. +// +// Examples: +// +// {"top": {"sub": "value", "sub2": "value2"} } +// -> +// {"top.sub": "value", "top.sub2": "value2" } +func Flatten(mp map[string]any) map[string]any { + if mp == nil { + return nil + } + + flatMp := make(map[string]any, len(mp)*2) + nreflect.FlatMap(reflect.ValueOf(mp), func(path string, val reflect.Value) { + flatMp[path] = val.Interface() + }) + + return flatMp +} + +// FlatWithFunc flat a tree-map with custom collect handle func +func FlatWithFunc(mp map[string]any, fn nreflect.FlatFunc) { + if mp == nil || fn == nil { + return + } + nreflect.FlatMap(reflect.ValueOf(mp), fn) +} diff --git a/nmap/convert_test.go b/nmap/convert_test.go new file mode 100644 index 0000000..10e47a6 --- /dev/null +++ b/nmap/convert_test.go @@ -0,0 +1,100 @@ +package nmap_test + +import ( + "fmt" + "git.noahlan.cn/noahlan/ntool/nmap" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestKeyToLower(t *testing.T) { + src := map[string]string{"A": "v0"} + ret := nmap.KeyToLower(src) + + assert.Contains(t, ret, "a") + assert.NotContains(t, ret, "A") +} + +func TestToStringMap(t *testing.T) { + src := map[string]any{"a": "v0", "b": 23} + ret := nmap.ToStringMap(src) + + assert.Eq(t, ret["a"], "v0") + assert.Eq(t, ret["b"], "23") + + keys := []string{"key0", "key1"} + + mp := nmap.CombineToSMap(keys, []string{"val0", "val1"}) + assert.Len(t, mp, 2) + assert.Eq(t, "val0", mp.Str("key0")) +} + +func TestToAnyMap(t *testing.T) { + src := map[string]string{"a": "v0", "b": "23"} + + mp := nmap.ToAnyMap(src) + assert.Len(t, mp, 2) + assert.Eq(t, "v0", mp["a"]) + + src1 := map[string]any{"a": "v0", "b": "23"} + mp = nmap.ToAnyMap(src1) + assert.Len(t, mp, 2) + assert.Eq(t, "v0", mp["a"]) + + _, err := nmap.TryAnyMap(123) + assert.Err(t, err) +} + +func TestHTTPQueryString(t *testing.T) { + src := map[string]any{"a": "v0", "b": 23} + str := nmap.HTTPQueryString(src) + + fmt.Println(str) + assert.Contains(t, str, "b=23") + assert.Contains(t, str, "a=v0") +} + +func TestToString2(t *testing.T) { + src := map[string]any{"a": "v0", "b": 23} + + s := nmap.ToString2(src) + assert.Contains(t, s, "b:23") + assert.Contains(t, s, "a:v0") +} + +func TestToString(t *testing.T) { + src := map[string]any{"a": "v0", "b": 23} + + s := nmap.ToString(src) + //dump.P(s) + assert.Contains(t, s, "b:23") + assert.Contains(t, s, "a:v0") + + s = nmap.ToString(nil) + assert.Eq(t, "", s) + + s = nmap.ToString(map[string]any{}) + assert.Eq(t, "{}", s) + + s = nmap.ToString(map[string]any{"": nil}) + assert.Eq(t, "{:}", s) +} + +func TestFlatten(t *testing.T) { + data := map[string]any{ + "name": "inhere", + "age": 234, + "top": map[string]any{ + "sub0": "val0", + "sub1": []string{"val1-0", "val1-1"}, + }, + } + + mp := nmap.Flatten(data) + assert.ContainsKeys(t, mp, []string{"age", "name", "top.sub0", "top.sub1[0]", "top.sub1[1]"}) + assert.Nil(t, nmap.Flatten(nil)) + + assert.NotPanics(t, func() { + nmap.FlatWithFunc(nil, nil) + }) +} diff --git a/nmap/data.go b/nmap/data.go new file mode 100644 index 0000000..d26f445 --- /dev/null +++ b/nmap/data.go @@ -0,0 +1,259 @@ +package nmap + +import ( + "git.noahlan.cn/noahlan/ntool/narr" + "git.noahlan.cn/noahlan/ntool/nmath" + "git.noahlan.cn/noahlan/ntool/nstr" + "strings" +) + +// Data an map data type +type Data map[string]any + +// Map alias of Data +type Map = Data + +// Has value on the data map +func (d Data) Has(key string) bool { + _, ok := d.GetByPath(key) + return ok +} + +// IsEmtpy if the data map +func (d Data) IsEmtpy() bool { + return len(d) == 0 +} + +// Value get from the data map +func (d Data) Value(key string) (any, bool) { + val, ok := d.GetByPath(key) + return val, ok +} + +// Get value from the data map. +// Supports dot syntax to get deep values. eg: top.sub +func (d Data) Get(key string) any { + if val, ok := d.GetByPath(key); ok { + return val + } + return nil +} + +// GetByPath get value from the data map by path. eg: top.sub +// Supports dot syntax to get deep values. +func (d Data) GetByPath(path string) (any, bool) { + if val, ok := d[path]; ok { + return val, true + } + + // is key path. + if strings.ContainsRune(path, '.') { + val, ok := GetByPath(path, d) + if ok { + return val, true + } + } + return nil, false +} + +// Set value to the data map +func (d Data) Set(key string, val any) { + d[key] = val +} + +// SetByPath sets a value in the map. +// Supports dot syntax to set deep values. +// +// For example: +// +// d.SetByPath("name.first", "Mat") +func (d Data) SetByPath(path string, value any) error { + if path == "" { + return nil + } + return d.SetByKeys(strings.Split(path, KeySepStr), value) +} + +// SetByKeys sets a value in the map by path keys. +// Supports dot syntax to set deep values. +// +// For example: +// +// d.SetByKeys([]string{"name", "first"}, "Mat") +func (d Data) SetByKeys(keys []string, value any) error { + kln := len(keys) + if kln == 0 { + return nil + } + + // special handle d is empty. + if len(d) == 0 { + if kln == 1 { + d.Set(keys[0], value) + } else { + d.Set(keys[0], MakeByKeys(keys[1:], value)) + } + return nil + } + + return SetByKeys((*map[string]any)(&d), keys, value) + // It's ok, but use `func (d *Data)` + // return SetByKeys((*map[string]any)(d), keys, value) +} + +// Default get value from the data map with default value +func (d Data) Default(key string, def any) any { + if val, ok := d.GetByPath(key); ok { + return val + } + return def +} + +// Int value get +func (d Data) Int(key string) int { + if val, ok := d.GetByPath(key); ok { + return nmath.QuietInt(val) + } + return 0 +} + +// Int64 value get +func (d Data) Int64(key string) int64 { + if val, ok := d.GetByPath(key); ok { + return nmath.QuietInt64(val) + } + return 0 +} + +// Uint value get +func (d Data) Uint(key string) uint64 { + if val, ok := d.GetByPath(key); ok { + return nmath.QuietUint(val) + } + return 0 +} + +// Str value get by key +func (d Data) Str(key string) string { + if val, ok := d.GetByPath(key); ok { + return nstr.SafeString(val) + } + return "" +} + +// Bool value get +func (d Data) Bool(key string) bool { + val, ok := d.GetByPath(key) + if !ok { + return false + } + + switch tv := val.(type) { + case string: + return nstr.QuietBool(tv) + case bool: + return tv + default: + return false + } +} + +// Strings get []string value +func (d Data) Strings(key string) []string { + val, ok := d.GetByPath(key) + if !ok { + return nil + } + + switch typVal := val.(type) { + case string: + return []string{typVal} + case []string: + return typVal + case []any: + return narr.SliceToStrings(typVal) + default: + return nil + } +} + +// StrSplit get strings by split key value +func (d Data) StrSplit(key, sep string) []string { + if val, ok := d.GetByPath(key); ok { + return strings.Split(nstr.SafeString(val), sep) + } + return nil +} + +// StringsByStr value get by key +func (d Data) StringsByStr(key string) []string { + if val, ok := d.GetByPath(key); ok { + return strings.Split(nstr.SafeString(val), ",") + } + return nil +} + +// StrMap get map[string]string value +func (d Data) StrMap(key string) map[string]string { + return d.StringMap(key) +} + +// StringMap get map[string]string value +func (d Data) StringMap(key string) map[string]string { + val, ok := d.GetByPath(key) + if !ok { + return nil + } + + switch tv := val.(type) { + case map[string]string: + return tv + case map[string]any: + return ToStringMap(tv) + default: + return nil + } +} + +// Sub get sub value as new Data +func (d Data) Sub(key string) Data { + if val, ok := d.GetByPath(key); ok { + if sub, ok := val.(map[string]any); ok { + return sub + } + } + return nil +} + +// Keys of the data map +func (d Data) Keys() []string { + keys := make([]string, 0, len(d)) + for k := range d { + keys = append(keys, k) + } + return keys +} + +// ToStringMap convert to map[string]string +func (d Data) ToStringMap() map[string]string { + return ToStringMap(d) +} + +// String data to string +func (d Data) String() string { + return ToString(d) +} + +// Load other data to current data map +func (d Data) Load(sub map[string]any) { + for name, val := range sub { + d[name] = val + } +} + +// LoadSMap to data +func (d Data) LoadSMap(smp map[string]string) { + for name, val := range smp { + d[name] = val + } +} diff --git a/nmap/data_test.go b/nmap/data_test.go new file mode 100644 index 0000000..ada2141 --- /dev/null +++ b/nmap/data_test.go @@ -0,0 +1,181 @@ +package nmap_test + +import ( + "git.noahlan.cn/noahlan/ntool/nmap" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "reflect" + "testing" +) + +func TestData_usage(t *testing.T) { + mp := nmap.Data{ + "k1": 23, + "k2": "ab", + "k3": "true", + "k4": false, + "k5": map[string]string{"a": "b"}, + "anyMp": map[string]any{"b": 23}, + } + + assert.True(t, mp.Has("k1")) + assert.True(t, mp.Bool("k3")) + assert.False(t, mp.Bool("k4")) + assert.False(t, mp.IsEmtpy()) + assert.Eq(t, 23, mp.Get("k1")) + assert.Eq(t, "b", mp.Get("k5.a")) + assert.Eq(t, 23, mp.Get("anyMp.b")) + + // int + assert.Eq(t, 23, mp.Int("k1")) + assert.Eq(t, int64(23), mp.Int64("k1")) + + // str + assert.Eq(t, "23", mp.Str("k1")) + assert.Eq(t, "ab", mp.Str("k2")) + + // set + mp.Set("new", "val1") + assert.Eq(t, "val1", mp.Str("new")) + + val, ok := mp.Value("new") + assert.True(t, ok) + assert.Eq(t, "val1", val) + + // not exists + assert.False(t, mp.Bool("notExists")) + assert.Eq(t, 0, mp.Int("notExists")) + assert.Eq(t, int64(0), mp.Int64("notExists")) + assert.Eq(t, "", mp.Str("notExists")) + + // default + assert.Eq(t, 23, mp.Default("k1", 10)) + assert.Eq(t, 10, mp.Default("notExists", 10)) + + assert.Nil(t, mp.StringMap("notExists")) + assert.Eq(t, map[string]string{"a": "b"}, mp.StringMap("k5")) + assert.Eq(t, map[string]string{"b": "23"}, mp.StringMap("anyMp")) +} + +func TestData_SetByPath(t *testing.T) { + mp := nmap.Data{ + "k2": "ab", + "k5": map[string]any{"a": "v0"}, + } + assert.Nil(t, mp.Get("k5.b")) + assert.Len(t, mp.Keys(), 2) + assert.NotEmpty(t, mp.ToStringMap()) + + err := mp.SetByPath("k5.b", "v2") + assert.NoErr(t, err) + // dump.P(mp) + assert.Eq(t, "v2", mp.Get("k5.b")) + + mp.Load(map[string]any{"k2": "val2", "k3": "val3"}) + assert.Eq(t, "val2", mp.Str("k2")) + assert.Eq(t, "val3", mp.Str("k3")) + + // sub + assert.Nil(t, mp.Sub("not-exists")) + sub := mp.Sub("k5") + assert.Eq(t, "v0", sub.Get("a")) + assert.Eq(t, "v2", sub.Get("b")) +} + +func TestData_SetByPath_case2(t *testing.T) { + mp := nmap.Data{} + assert.Eq(t, 0, len(mp)) + + err := mp.SetByPath("top2.inline.list.ids", []int{234, 345, 456}) + assert.NoErr(t, err) + assert.Eq(t, []int{234, 345, 456}, mp.Get("top2.inline.list.ids")) + + err = mp.SetByPath("top2.sub.var-refer", "val1") + assert.NoErr(t, err) + assert.Eq(t, "val1", mp.Get("top2.sub.var-refer")) + + err = mp.SetByPath("top2.sub.key2-other", "val2") + assert.NoErr(t, err) + assert.Eq(t, "val2", mp.Get("top2.sub.key2-other")) + // dump.P(mp) +} + +func TestData_SetByPath_case3(t *testing.T) { + mp := nmap.Data{} + assert.Eq(t, 0, len(mp)) + + err := mp.SetByPath("top.sub.key3", "false") + assert.NoErr(t, err) + assert.Eq(t, "false", mp.Get("top.sub.key3")) + assert.False(t, mp.Bool("top.sub.key3")) + + err = mp.SetByPath("top.sub.key4[0]", "abc") + assert.NoErr(t, err) + + err = mp.SetByPath("top.sub.key4[1]", "def") + assert.NoErr(t, err) + sli := mp.Get("top.sub.key4") + assert.IsKind(t, reflect.Slice, sli) + assert.Len(t, sli, 2) + // dump.P(mp, sli) +} + +// top.sub.key5[0].f1 = ab +// top.sub.key5[1].f2 = de +func TestData_SetByPath_case4(t *testing.T) { + mp := nmap.Data{} + assert.Eq(t, 0, len(mp)) + + err := mp.SetByPath("top.sub.key3", "false") + assert.NoErr(t, err) + assert.Eq(t, "false", mp.Get("top.sub.key3")) + + err = mp.SetByPath("top.sub.key5[0].f1", "val1") + assert.NoErr(t, err) + // dump.P(mp) + + err = mp.SetByPath("top.sub.key5[1].f2", "val2") + assert.NoErr(t, err) + //dump.P(mp) + sli := mp.Get("top.sub.key5") + assert.IsKind(t, reflect.Slice, sli) + assert.Len(t, sli, 2) +} + +func TestData_SetByKeys_emptyData(t *testing.T) { + // one level + mp := make(nmap.Data) + err := mp.SetByKeys([]string{"k3"}, "v3") + assert.NoErr(t, err) + //dump.P(mp) + + assert.Eq(t, "v3", mp.Str("k3")) + + // two level + mp1 := make(nmap.Data) + err = mp1.SetByKeys([]string{"k5", "b"}, "v2") + assert.NoErr(t, err) + //dump.P(mp1) + + assert.Eq(t, "v2", mp1.Get("k5.b")) +} + +func TestData_SetByKeys(t *testing.T) { + mp := nmap.Data{ + "k2": "ab", + "k5": map[string]any{"a": "v0"}, + } + assert.Nil(t, mp.Get("k3")) + assert.Nil(t, mp.Get("k5.b")) + + assert.NoErr(t, mp.SetByKeys([]string{}, "v3")) + + err := mp.SetByKeys([]string{"k3"}, "v3") + assert.NoErr(t, err) + assert.Eq(t, "v3", mp.Str("k3")) + + err = mp.SetByKeys([]string{"k5", "b"}, "v2") + assert.NoErr(t, err) + + // dump.P(mp) + assert.Eq(t, "v2", mp.Get("k5.b")) +} diff --git a/nmap/errors.go b/nmap/errors.go new file mode 100644 index 0000000..fdfd093 --- /dev/null +++ b/nmap/errors.go @@ -0,0 +1,39 @@ +package nmap + +import "strings" + +// ErrMap multi error map +type ErrMap map[string]error + +// Error string +func (e ErrMap) Error() string { + var sb strings.Builder + for name, err := range e { + sb.WriteString(name) + sb.WriteByte(':') + sb.WriteString(err.Error()) + sb.WriteByte('\n') + } + return sb.String() +} + +// ErrorOrNil error +func (e ErrMap) ErrorOrNil() error { + if len(e) == 0 { + return nil + } + return e +} + +// IsEmpty error +func (e ErrMap) IsEmpty() bool { + return len(e) == 0 +} + +// One error +func (e ErrMap) One() error { + for _, err := range e { + return err + } + return nil +} diff --git a/nmap/format.go b/nmap/format.go new file mode 100644 index 0000000..86de1fa --- /dev/null +++ b/nmap/format.go @@ -0,0 +1,124 @@ +package nmap + +import ( + "git.noahlan.cn/noahlan/ntool/ndef" + "git.noahlan.cn/noahlan/ntool/nstr" + "io" + "reflect" +) + +// MapFormatter struct +type MapFormatter struct { + ndef.BaseFormatter + // Prefix string for each element + Prefix string + // Indent string for each element + Indent string + // ClosePrefix string for last "}" + ClosePrefix string + // AfterReset after reset on call Format(). + // AfterReset bool +} + +// NewFormatter instance +func NewFormatter(mp any) *MapFormatter { + f := &MapFormatter{} + f.Src = mp + + return f +} + +// WithFn for config self +func (f *MapFormatter) WithFn(fn func(f *MapFormatter)) *MapFormatter { + fn(f) + return f +} + +// WithIndent string +func (f *MapFormatter) WithIndent(indent string) *MapFormatter { + f.Indent = indent + return f +} + +// FormatTo to custom buffer +func (f *MapFormatter) FormatTo(w io.Writer) { + f.SetOutput(w) + f.doFormat() +} + +// Format to string +func (f *MapFormatter) String() string { + return f.Format() +} + +// Format to string +func (f *MapFormatter) Format() string { + f.doFormat() + return f.BsWriter().String() +} + +// Format map data to string. +// +//goland:noinspection GoUnhandledErrorResult +func (f *MapFormatter) doFormat() { + if f.Src == nil { + return + } + + rv, ok := f.Src.(reflect.Value) + if !ok { + rv = reflect.ValueOf(f.Src) + } + + rv = reflect.Indirect(rv) + if rv.Kind() != reflect.Map { + return + } + + buf := f.BsWriter() + ln := rv.Len() + if ln == 0 { + buf.WriteString("{}") + return + } + + // buf.Grow(ln * 16) + buf.WriteByte('{') + + indentLn := len(f.Indent) + if indentLn > 0 { + buf.WriteByte('\n') + } + + for i, key := range rv.MapKeys() { + kStr := nstr.SafeString(key.Interface()) + if indentLn > 0 { + buf.WriteString(f.Indent) + } + + buf.WriteString(kStr) + buf.WriteByte(':') + + vStr := nstr.SafeString(rv.MapIndex(key).Interface()) + buf.WriteString(vStr) + if i < ln-1 { + buf.WriteByte(',') + + // no indent, with space + if indentLn == 0 { + buf.WriteByte(' ') + } + } + + // with newline + if indentLn > 0 { + buf.WriteByte('\n') + } + } + + if f.ClosePrefix != "" { + buf.WriteString(f.ClosePrefix) + } + + buf.WriteByte('}') +} diff --git a/nmap/format_test.go b/nmap/format_test.go new file mode 100644 index 0000000..431cce9 --- /dev/null +++ b/nmap/format_test.go @@ -0,0 +1,48 @@ +package nmap_test + +import ( + "fmt" + "git.noahlan.cn/noahlan/ntool/nmap" + "git.noahlan.cn/noahlan/ntool/ntest" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestNewFormatter(t *testing.T) { + mp := map[string]any{"a": "v0", "b": 23} + + mf := nmap.NewFormatter(mp) + assert.Contains(t, mf.String(), "b:23") + + buf := ntest.NewTestWriter() + mf = nmap.NewFormatter(mp).WithFn(func(f *nmap.MapFormatter) { + f.Indent = " " + }) + mf.FormatTo(buf) + assert.Contains(t, buf.String(), "\n ") + fmt.Println(buf.String()) + + s := nmap.FormatIndent(mp, " ") + fmt.Println(s) + assert.Contains(t, s, "\n ") + + s = nmap.FormatIndent(mp, "") + fmt.Println(s) + assert.NotContains(t, s, "\n ") +} + +func TestFormatIndent_mlevel(t *testing.T) { + mp := map[string]any{"a": "v0", "b": 23} + + mp["subs"] = map[string]string{ + "sub_k1": "sub val1", + "sub_k2": "sub val2", + } + + s := nmap.FormatIndent(mp, "") + fmt.Println(s) + assert.NotContains(t, s, "\n ") + + s = nmap.FormatIndent(mp, " ") + fmt.Println(s) +} diff --git a/nmap/get.go b/nmap/get.go new file mode 100644 index 0000000..0392e06 --- /dev/null +++ b/nmap/get.go @@ -0,0 +1,169 @@ +package nmap + +import ( + "reflect" + "strconv" + "strings" +) + +// some consts for separators +const ( + Wildcard = "*" + PathSep = "." +) + +// DeepGet value by key path. eg "top" "top.sub" +func DeepGet(mp map[string]any, path string) (val any) { + val, _ = GetByPath(path, mp) + return +} + +// QuietGet value by key path. eg "top" "top.sub" +func QuietGet(mp map[string]any, path string) (val any) { + val, _ = GetByPath(path, mp) + return +} + +// GetByPath get value by key path from a map(map[string]any). eg "top" "top.sub" +func GetByPath(path string, mp map[string]any) (val any, ok bool) { + if val, ok := mp[path]; ok { + return val, true + } + + // no sub key + if len(mp) == 0 || strings.IndexByte(path, '.') < 1 { + return nil, false + } + + // has sub key. eg. "top.sub" + keys := strings.Split(path, ".") + return GetByPathKeys(mp, keys) +} + +// GetByPathKeys get value by path keys from a map(map[string]any). eg "top" "top.sub" +// +// Example: +// +// mp := map[string]any{ +// "top": map[string]any{ +// "sub": "value", +// }, +// } +// val, ok := GetByPathKeys(mp, []string{"top", "sub"}) // return "value", true +func GetByPathKeys(mp map[string]any, keys []string) (val any, ok bool) { + kl := len(keys) + if kl == 0 { + return mp, true + } + + // find top item data use top key + var item any + + topK := keys[0] + if item, ok = mp[topK]; !ok { + return + } + + // find sub item data use sub key + for i, k := range keys[1:] { + switch tData := item.(type) { + case map[string]string: // is string map + if item, ok = tData[k]; !ok { + return + } + case map[string]any: // is map(decode from toml/json/yaml) + if item, ok = tData[k]; !ok { + return + } + case map[any]any: // is map(decode from yaml.v2) + if item, ok = tData[k]; !ok { + return + } + case []map[string]any: // is an any-map slice + if k == Wildcard { + if kl == i+2 { + return tData, true + } + + sl := make([]any, 0, len(tData)) + for _, v := range tData { + if val, ok = GetByPathKeys(v, keys[i+2:]); ok { + sl = append(sl, val) + } + } + return sl, true + } + + // k is index number + idx, err := strconv.Atoi(k) + if err != nil { + return nil, false + } + + if idx >= len(tData) { + return nil, false + } + item = tData[idx] + default: + rv := reflect.ValueOf(tData) + // check is slice + if rv.Kind() == reflect.Slice { + i, err := strconv.Atoi(k) + if err != nil { + return nil, false + } + if i >= rv.Len() { + return nil, false + } + + item = rv.Index(i).Interface() + continue + } + + // as error + return nil, false + } + } + + return item, true +} + +// Keys get all keys of the given map. +func Keys(mp any) (keys []string) { + rftVal := reflect.Indirect(reflect.ValueOf(mp)) + if rftVal.Kind() != reflect.Map { + return + } + + keys = make([]string, 0, rftVal.Len()) + for _, key := range rftVal.MapKeys() { + keys = append(keys, key.String()) + } + return +} + +// Values get all values from the given map. +func Values(mp any) (values []any) { + rv := reflect.Indirect(reflect.ValueOf(mp)) + if rv.Kind() != reflect.Map { + return + } + + values = make([]any, 0, rv.Len()) + for _, key := range rv.MapKeys() { + values = append(values, rv.MapIndex(key).Interface()) + } + return +} + +// EachAnyMap iterates the given map and calls the given function for each item. +func EachAnyMap(mp any, fn func(key string, val any)) { + rv := reflect.Indirect(reflect.ValueOf(mp)) + if rv.Kind() != reflect.Map { + panic("not a map value") + } + + for _, key := range rv.MapKeys() { + fn(key.String(), rv.MapIndex(key).Interface()) + } +} diff --git a/nmap/get_test.go b/nmap/get_test.go new file mode 100644 index 0000000..a6c639b --- /dev/null +++ b/nmap/get_test.go @@ -0,0 +1,211 @@ +package nmap_test + +import ( + "git.noahlan.cn/noahlan/ntool/nmap" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestGetByPath(t *testing.T) { + mp := map[string]any{ + "key0": "val0", + "key1": map[string]string{"sk0": "sv0"}, + "key2": []string{"sv1", "sv2"}, + "key3": map[string]any{"sk1": "sv1"}, + "key4": []int{1, 2}, + "key5": []any{1, "2", true}, + "mlMp": []map[string]any{ + { + "code": "001", + "names": []string{"John", "abc"}, + }, + { + "code": "002", + "names": []string{"Tom", "def"}, + }, + }, + } + + tests := []struct { + path string + want any + ok bool + }{ + {"key0", "val0", true}, + {"key1.sk0", "sv0", true}, + {"key3.sk1", "sv1", true}, + // not exists + {"not-exits", nil, false}, + {"key2.not-exits", nil, false}, + {"not-exits.subkey", nil, false}, + // slices behaviour + {"key2", mp["key2"], true}, + {"key2.0", "sv1", true}, + {"key2.1", "sv2", true}, + {"key4.0", 1, true}, + {"key4.1", 2, true}, + {"key5.0", 1, true}, + {"key5.1", "2", true}, + {"key5.2", true, true}, + // out of bound + {"key4.3", nil, false}, + // deep sub map + {"mlMp.*.code", []any{"001", "002"}, true}, + {"mlMp.*.names", []any{ + []string{"John", "abc"}, + []string{"Tom", "def"}, + }, true}, + {"mlMp.*.names.1", []any{"abc", "def"}, true}, + } + + for _, tt := range tests { + v, ok := nmap.GetByPath(tt.path, mp) + assert.Eq(t, tt.ok, ok, tt.path) + assert.Eq(t, tt.want, v, tt.path) + } + + // v, ok := nmap.GetByPath("mlMp.*.names.1", mp) + // assert.True(t, ok) + // assert.Eq(t, []any{"abc", "def"}, v) +} + +var mlMp = map[string]any{ + "names": []string{"John", "Jane", "abc"}, + "coding": []map[string]any{ + { + "details": map[string]any{ + "em": map[string]any{ + "code": "001-1", + "encounter_uid": "1-1", + "billing_provider": "Test provider 01-1", + "resident_provider": "Test Resident Provider-1", + }, + }, + }, + { + "details": map[string]any{ + "em": map[string]any{ + "code": "001", + "encounter_uid": "1", + "billing_provider": "Test provider 01", + "resident_provider": "Test Resident Provider", + }, + "cpt": []map[string]any{ + { + "code": "001", + "encounter_uid": "2", + "work_item_uid": "3", + "billing_provider": "Test provider 001", + "resident_provider": "Test Resident Provider", + }, + { + "code": "OBS01", + "encounter_uid": "3", + "work_item_uid": "4", + "billing_provider": "Test provider OBS01", + "resident_provider": "Test Resident Provider", + }, + { + "code": "SU002", + "encounter_uid": "5", + "work_item_uid": "6", + "billing_provider": "Test provider SU002", + "resident_provider": "Test Resident Provider", + }, + }, + }, + }, + }, +} + +func TestGetByPath_deepPath(t *testing.T) { + val, ok := nmap.GetByPath("coding.0.details.em.code", mlMp) + assert.True(t, ok) + assert.NotEmpty(t, val) + + val, ok = nmap.GetByPath("coding.*.details", mlMp) + assert.True(t, ok) + assert.NotEmpty(t, val) + // dump.P(ok, val) + + val, ok = nmap.GetByPath("coding.*.details.em", mlMp) + //dump.P(ok, val) + assert.True(t, ok) + + val, ok = nmap.GetByPath("coding.*.details.em.code", mlMp) + //dump.P(ok, val) + assert.True(t, ok) + + val, ok = nmap.GetByPath("coding.*.details.cpt.*.encounter_uid", mlMp) + //dump.P(ok, val) + assert.True(t, ok) + + val, ok = nmap.GetByPath("coding.*.details.cpt.*.work_item_uid", mlMp) + //dump.P(ok, val) + assert.True(t, ok) +} + +func TestKeys(t *testing.T) { + mp := map[string]any{ + "key0": "v0", + "key1": "v1", + "key2": 34, + } + + ln := len(mp) + ret := nmap.Keys(mp) + assert.Len(t, ret, ln) + assert.Contains(t, ret, "key0") + assert.Contains(t, ret, "key1") + assert.Contains(t, ret, "key2") + + ret = nmap.Keys(&mp) + assert.Len(t, ret, ln) + assert.Contains(t, ret, "key0") + assert.Contains(t, ret, "key1") + + ret = nmap.Keys(struct { + a string + }{"v"}) + + assert.Len(t, ret, 0) +} + +func TestValues(t *testing.T) { + mp := map[string]any{ + "key0": "v0", + "key1": "v1", + "key2": 34, + } + + ln := len(mp) + ret := nmap.Values(mp) + + assert.Len(t, ret, ln) + assert.Contains(t, ret, "v0") + assert.Contains(t, ret, "v1") + assert.Contains(t, ret, 34) + + ret = nmap.Values(struct { + a string + }{"v"}) + + assert.Len(t, ret, 0) +} + +func TestEachAnyMap(t *testing.T) { + mp := map[string]any{ + "key0": "v0", + "key1": "v1", + "key2": 34, + } + + nmap.EachAnyMap(mp, func(k string, v any) { + assert.NotEmpty(t, k) + assert.NotEmpty(t, v) + }) + + assert.Panics(t, func() { + nmap.EachAnyMap(1, nil) + }) +} diff --git a/nmap/setval.go b/nmap/setval.go new file mode 100644 index 0000000..fc98a41 --- /dev/null +++ b/nmap/setval.go @@ -0,0 +1,339 @@ +package nmap + +import ( + "fmt" + "git.noahlan.cn/noahlan/ntool/nstr" + "reflect" + "strconv" + "strings" +) + +// SetByPath set sub-map value by key path. +// Supports dot syntax to set deep values. +// +// For example: +// +// SetByPath("name.first", "Mat") +func SetByPath(mp *map[string]any, path string, val any) error { + return SetByKeys(mp, strings.Split(path, KeySepStr), val) +} + +// SetByKeys set sub-map value by path keys. +// Supports dot syntax to set deep values. +// +// For example: +// +// SetByKeys([]string{"name", "first"}, "Mat") +func SetByKeys(mp *map[string]any, keys []string, val any) (err error) { + kln := len(keys) + if kln == 0 { + return nil + } + + mpv := *mp + if len(mpv) == 0 { + *mp = MakeByKeys(keys, val) + return nil + } + + topK := keys[0] + if kln == 1 { + mpv[topK] = val + return nil + } + + if _, ok := mpv[topK]; !ok { + mpv[topK] = MakeByKeys(keys[1:], val) + return nil + } + + rv := reflect.ValueOf(mp).Elem() + return setMapByKeys(rv, keys, reflect.ValueOf(val)) +} + +func setMapByKeys(rv reflect.Value, keys []string, nv reflect.Value) (err error) { + if rv.Kind() != reflect.Map { + return fmt.Errorf("input parameter#rv must be a Map, but was %s", rv.Kind()) + } + + // If the map is nil, make a new map + if rv.IsNil() { + mapType := reflect.MapOf(rv.Type().Key(), rv.Type().Elem()) + rv.Set(reflect.MakeMap(mapType)) + } + + var ok bool + maxI := len(keys) - 1 + for i, key := range keys { + idx := -1 + isMap := rv.Kind() == reflect.Map + isSlice := rv.Kind() == reflect.Slice + isLast := i == len(keys)-1 + + // slice index key must be ended on the keys. + // eg: "top.arr[2]" -> "arr[2]" + if pos := strings.IndexRune(key, '['); pos > 0 { + var realKey string + if realKey, idx, ok = parseArrKeyIndex(key); ok { + // update value + key = realKey + if !isMap { + err = fmt.Errorf( + "current value#%s type is %s, cannot get sub-value by key: %s", + strings.Join(keys[i:], "."), + rv.Kind(), + key, + ) + break + } + + rftK := reflect.ValueOf(key) + tmpV := rv.MapIndex(rftK) + if !tmpV.IsValid() { + if isLast { + sliVal := reflect.MakeSlice(reflect.SliceOf(nv.Type()), idx+1, idx+1) + sliVal.Index(idx).Set(nv) + rv.SetMapIndex(rftK, sliVal) + } else { + // deep make map by keys + newVal := MakeByKeys(keys[i+1:], nv.Interface()) + mpVal := reflect.ValueOf(newVal) + + sliVal := reflect.MakeSlice(reflect.SliceOf(mpVal.Type()), idx+1, idx+1) + sliVal.Index(idx).Set(mpVal) + + rv.SetMapIndex(rftK, sliVal) + } + break + } + + // get real type: any -> map + if tmpV.Kind() == reflect.Interface { + tmpV = tmpV.Elem() + } + + if tmpV.Kind() != reflect.Slice { + err = fmt.Errorf( + "current value#%s type is %s, cannot set sub by index: %d", + strings.Join(keys[i:], "."), + tmpV.Kind(), + idx, + ) + break + } + + wantLen := idx + 1 + sliLen := tmpV.Len() + elemTyp := tmpV.Type().Elem() + + if wantLen > sliLen { + newAdd := reflect.MakeSlice(tmpV.Type(), 0, wantLen-sliLen) + for i := 0; i < wantLen-sliLen; i++ { + newAdd = reflect.Append(newAdd, reflect.New(elemTyp).Elem()) + } + + tmpV = reflect.AppendSlice(tmpV, newAdd) + } + + if !isLast { + if elemTyp.Kind() == reflect.Map { + err := setMapByKeys(tmpV.Index(idx), keys[i+1:], nv) + if err != nil { + return err + } + + // tmpV.Index(idx).Set(elemV) + rv.SetMapIndex(rftK, tmpV) + } else { + err = fmt.Errorf( + "key %s[%d] elem must be map for set sub-value by remain path: %s", + key, + idx, + strings.Join(keys[i:], "."), + ) + } + } else { + // last - set value + tmpV.Index(idx).Set(nv) + rv.SetMapIndex(rftK, tmpV) + } + break + } + } + + // set value on last key + if isLast { + if isMap { + rv.SetMapIndex(reflect.ValueOf(key), nv) + break + } + + if isSlice { + // key is slice index + if nstr.IsNumberStr(key) { + idx, _ = strconv.Atoi(key) + } + + if idx > -1 { + wantLen := idx + 1 + sliLen := rv.Len() + + if wantLen > sliLen { + elemTyp := rv.Type().Elem() + newAdd := reflect.MakeSlice(rv.Type(), 0, wantLen-sliLen) + + for i := 0; i < wantLen-sliLen; i++ { + newAdd = reflect.Append(newAdd, reflect.New(elemTyp).Elem()) + } + + if !rv.CanAddr() { + err = fmt.Errorf("cannot set value to a cannot addr slice, key: %s", key) + break + } + + rv.Set(reflect.AppendSlice(rv, newAdd)) + } + + rv.Index(idx).Set(nv) + } else { + err = fmt.Errorf("cannot set slice value by named key %q", key) + } + } else { + err = fmt.Errorf( + "cannot set sub-value for type %q(path %q, key %q)", + rv.Kind(), + strings.Join(keys[:i], "."), + key, + ) + } + + break + } + + if isMap { + rftK := reflect.ValueOf(key) + if tmpV := rv.MapIndex(rftK); tmpV.IsValid() { + var isPtr bool + // get real type: any -> map + tmpV, isPtr = getRealVal(tmpV) + if tmpV.Kind() == reflect.Map { + rv = tmpV + continue + } + + // sub is slice and is not ptr + if tmpV.Kind() == reflect.Slice { + if isPtr { + rv = tmpV + continue // to (E) + } + + // next key is index number. + nxtKey := keys[i+1] + if nstr.IsNumberStr(nxtKey) { + idx, _ = strconv.Atoi(nxtKey) + sliLen := tmpV.Len() + wantLen := idx + 1 + + if wantLen > sliLen { + elemTyp := tmpV.Type().Elem() + newAdd := reflect.MakeSlice(tmpV.Type(), 0, wantLen-sliLen) + for i := 0; i < wantLen-sliLen; i++ { + newAdd = reflect.Append(newAdd, reflect.New(elemTyp).Elem()) + } + + tmpV = reflect.AppendSlice(tmpV, newAdd) + } + + // rv = tmpV.Index(idx) // TODO + if i+1 == maxI { + tmpV.Index(idx).Set(nv) + } else { + err := setMapByKeys(tmpV.Index(idx), keys[i+1:], nv) + if err != nil { + return err + } + } + + rv.SetMapIndex(rftK, tmpV) + } else { + err = fmt.Errorf("cannot set slice value by named key %s(parent: %s)", nxtKey, key) + } + } else { + err = fmt.Errorf( + "map item type is %s(path:%q), cannot set sub-value by path %q", + tmpV.Kind(), + strings.Join(keys[0:i+1], "."), + strings.Join(keys[i+1:], "."), + ) + } + } else { + // deep make map by keys + newVal := MakeByKeys(keys[i+1:], nv.Interface()) + rv.SetMapIndex(rftK, reflect.ValueOf(newVal)) + } + + break + } else if isSlice && nstr.IsNumberStr(key) { // (E). slice from ptr slice + idx, _ = strconv.Atoi(key) + sliLen := rv.Len() + wantLen := idx + 1 + + if wantLen > sliLen { + elemTyp := rv.Type().Elem() + newAdd := reflect.MakeSlice(rv.Type(), 0, wantLen-sliLen) + for i := 0; i < wantLen-sliLen; i++ { + newAdd = reflect.Append(newAdd, reflect.New(elemTyp).Elem()) + } + + rv = reflect.AppendSlice(rv, newAdd) + } + + rv = rv.Index(idx) + } else { + err = fmt.Errorf( + "map item type is %s, cannot set sub-value by path %q", + rv.Kind(), + strings.Join(keys[i:], "."), + ) + } + } + return +} + +func getRealVal(rv reflect.Value) (reflect.Value, bool) { + // get real type: any -> map + if rv.Kind() == reflect.Interface { + rv = rv.Elem() + } + + isPtr := false + if rv.Kind() == reflect.Ptr { + isPtr = true + rv = rv.Elem() + } + + return rv, isPtr +} + +// "arr[2]" => "arr", 2, true +func parseArrKeyIndex(key string) (string, int, bool) { + pos := strings.IndexRune(key, '[') + if pos < 1 || !strings.HasSuffix(key, "]") { + return key, 0, false + } + + var idx int + var err error + + idxStr := key[pos+1 : len(key)-1] + if idxStr != "" { + idx, err = strconv.Atoi(idxStr) + if err != nil { + return key, 0, false + } + } + + key = key[:pos] + return key, idx, true +} diff --git a/nmap/setval_test.go b/nmap/setval_test.go new file mode 100644 index 0000000..5cee4d1 --- /dev/null +++ b/nmap/setval_test.go @@ -0,0 +1,188 @@ +package nmap_test + +import ( + "git.noahlan.cn/noahlan/ntool/nmap" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func makMapForSetByPath() map[string]any { + return map[string]any{ + "key0": "v0", + "key2": 34, + "key3": map[string]any{ + "k301": "v301", + "k303": []string{"v303-0", "v303-1"}, + "k304": map[string]any{ + "k3041": "v3041", + "k3042": []string{"k3042-0", "k3042-1"}, + }, + "k305": []any{ + map[string]string{ + "k3051": "v3051", + }, + }, + }, + "key4": map[string]string{ + "k401": "v401", + }, + "key6": []any{ + map[string]string{ + "k3051": "v3051", + }, + }, + "key7": nil, + } +} + +func TestSetByKeys_basic(t *testing.T) { + mp := make(map[string]any) + err := nmap.SetByKeys(&mp, []string{}, "val") + assert.NoErr(t, err) + + mp["key"] = "val1" + err = nmap.SetByKeys(&mp, []string{"key1", "k01"}, "val01") + assert.NoErr(t, err) + + assert.Eq(t, "val01", nmap.QuietGet(mp, "key1.k01")) +} + +func TestSetByKeys_emptyMap(t *testing.T) { + mp := make(map[string]any) + err := nmap.SetByKeys(&mp, []string{"k3"}, "val") + assert.NoErr(t, err) + assert.Eq(t, "val", nmap.QuietGet(mp, "k3")) + + mp = make(map[string]any) + err = nmap.SetByKeys(&mp, []string{"k5", "b"}, "v2") + // dump.P(mp) + assert.NoErr(t, err) + assert.Eq(t, "v2", nmap.QuietGet(mp, "k5.b")) +} + +func TestSetByKeys_map_add_key(t *testing.T) { + mp := makMapForSetByPath() + val := "add-new-key" + + // top level + keys1 := []string{"key501"} // ok + err1 := nmap.SetByKeys(&mp, keys1, val) + assert.NoErr(t, err1) + assert.ContainsKey(t, mp, "key501") + assert.Eq(t, val, nmap.QuietGet(mp, "key501")) + + // two level + keys2 := []string{"key3", "k30201"} // ok + err2 := nmap.SetByKeys(&mp, keys2, val) + assert.NoErr(t, err2) + assert.Eq(t, val, nmap.QuietGet(mp, "key3.k30201")) + + // more deep + keys3 := []string{"key3", "k304", "k3043"} // ok + err3 := nmap.SetByKeys(&mp, keys3, val) + assert.NoErr(t, err3) + assert.Eq(t, val, nmap.QuietGet(mp, "key3.k304.k3043")) + + // set to map[string]string + keys4 := []string{"key4", "k402"} // ok + err4 := nmap.SetByKeys(&mp, keys4, val) + assert.NoErr(t, err4) + assert.Eq(t, val, nmap.DeepGet(mp, "key4.k402")) + //dump.Println(mp) +} + +func TestSetByKeys_map_up_val(t *testing.T) { + mp := makMapForSetByPath() + val := "set-new-val" + + keys1 := []string{"key0"} // ok + err1 := nmap.SetByKeys(&mp, keys1, val) + assert.NoErr(t, err1) + assert.Eq(t, val, nmap.QuietGet(mp, "key0")) + + keys2 := []string{"key3", "k301"} // ok + err2 := nmap.SetByKeys(&mp, keys2, val) + assert.NoErr(t, err2) + assert.Eq(t, val, nmap.QuietGet(mp, "key3.k301")) + + keys4 := []string{"key4", "k401"} // ok + err4 := nmap.SetByKeys(&mp, keys4, val) + assert.NoErr(t, err4) + assert.Eq(t, val, nmap.DeepGet(mp, "key4.k401")) + //dump.Println(mp) +} + +func TestSetByKeys_slice_upAdd_method1(t *testing.T) { + mp := makMapForSetByPath() + + nVal := "set-new-value" + keys3 := []string{"key3", "k303", "1"} // ok + err3 := nmap.SetByKeys(&mp, keys3, nVal) + assert.NoErr(t, err3) + assert.Eq(t, nVal, nmap.QuietGet(mp, "key3.k303.1")) + + nVal2 := "add-new-item" + keys4 := []string{"key3", "k303", "2"} // ok + err4 := nmap.SetByKeys(&mp, keys4, nVal2) + assert.NoErr(t, err4) + assert.Eq(t, nVal2, nmap.QuietGet(mp, "key3.k303.2")) + //dump.Println(mp) +} + +func TestSetByKeys_slice_upAdd_method2(t *testing.T) { + mp := makMapForSetByPath() + nVal := "new-value" + + keys2 := []string{"key3", "k303[1]"} // ok + err2 := nmap.SetByKeys(&mp, keys2, nVal) + assert.NoErr(t, err2) + assert.Eq(t, nVal, nmap.QuietGet(mp, "key3.k303.1")) + + nVal2 := "add-new-item" + keys3 := []string{"key3", "k303[2]"} // ok + err3 := nmap.SetByKeys(&mp, keys3, nVal2) + assert.NoErr(t, err3) + assert.Eq(t, nVal2, nmap.QuietGet(mp, "key3.k303.2")) + + //dump.Println(mp) +} + +func TestSetByPath(t *testing.T) { + mp := map[string]any{ + "key0": "v0", + "key1": "v1", + "key2": 34, + } + + err := nmap.SetByPath(&mp, "key0", "v00") + assert.NoErr(t, err) + assert.ContainsKey(t, mp, "key0") + assert.Eq(t, "v00", mp["key0"]) + + err = nmap.SetByPath(&mp, "key3", map[string]any{ + "k301": "v301", + "k302": 234, + "k303": []string{"v303-1", "v303-2"}, + "k304": nil, + }) + + // dump.P(mp) + assert.NoErr(t, err) + assert.ContainsKeys(t, mp, []string{"key3"}) + assert.ContainsKeys(t, mp["key3"], []string{"k301", "k302", "k303", "k304"}) + + err = nmap.SetByPath(&mp, "key4", map[string]string{ + "k401": "v401", + }) + assert.NoErr(t, err) + assert.ContainsKey(t, mp, "key3") + + val, ok := nmap.GetByPath("key4.k401", mp) + assert.True(t, ok) + assert.Eq(t, "v401", val) + + err = nmap.SetByPath(&mp, "key4.k402", "v402") + assert.NoErr(t, err) + + //dump.P(mp) +} diff --git a/nmap/smap.go b/nmap/smap.go new file mode 100644 index 0000000..425e8cc --- /dev/null +++ b/nmap/smap.go @@ -0,0 +1,126 @@ +package nmap + +import ( + "git.noahlan.cn/noahlan/ntool/nmath" + "git.noahlan.cn/noahlan/ntool/nstr" +) + +// SMap is alias of map[string]string +type SMap map[string]string + +// IsEmpty of the data map +func (m SMap) IsEmpty() bool { + return len(m) == 0 +} + +// Has key on the data map +func (m SMap) Has(key string) bool { + _, ok := m[key] + return ok +} + +// HasValue on the data map +func (m SMap) HasValue(val string) bool { + for _, v := range m { + if v == val { + return true + } + } + return false +} + +// Value get from the data map +func (m SMap) Value(key string) (string, bool) { + val, ok := m[key] + return val, ok +} + +// Default get value by key. if not found, return defVal +func (m SMap) Default(key, defVal string) string { + if val, ok := m[key]; ok { + return val + } + return defVal +} + +// Get value by key +func (m SMap) Get(key string) string { + return m[key] +} + +// Int value get +func (m SMap) Int(key string) int { + if val, ok := m[key]; ok { + return nmath.QuietInt(val) + } + return 0 +} + +// Int64 value get +func (m SMap) Int64(key string) int64 { + if val, ok := m[key]; ok { + return nmath.QuietInt64(val) + } + return 0 +} + +// Str value get +func (m SMap) Str(key string) string { + return m[key] +} + +// Bool value get +func (m SMap) Bool(key string) bool { + if val, ok := m[key]; ok { + return nstr.QuietBool(val) + } + return false +} + +// Ints value to []int +func (m SMap) Ints(key string) []int { + if val, ok := m[key]; ok { + return nstr.Ints(val, ValSepStr) + } + return nil +} + +// Strings value to []string +func (m SMap) Strings(key string) (ss []string) { + if val, ok := m[key]; ok { + return nstr.ToSlice(val, ValSepStr) + } + return +} + +// Keys of the string-map +func (m SMap) Keys() []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} + +// Values of the string-map +func (m SMap) Values() []string { + ss := make([]string, 0, len(m)) + for _, v := range m { + ss = append(ss, v) + } + return ss +} + +// ToKVPairs slice convert. eg: {k1:v1,k2:v2} => {k1,v1,k2,v2} +func (m SMap) ToKVPairs() []string { + pairs := make([]string, 0, len(m)*2) + for k, v := range m { + pairs = append(pairs, k, v) + } + return pairs +} + +// String data to string +func (m SMap) String() string { + return ToString2(m) +} diff --git a/nmap/smap_test.go b/nmap/smap_test.go new file mode 100644 index 0000000..91bf630 --- /dev/null +++ b/nmap/smap_test.go @@ -0,0 +1,61 @@ +package nmap_test + +import ( + "fmt" + "git.noahlan.cn/noahlan/ntool/nmap" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestSMap_usage(t *testing.T) { + mp := nmap.SMap{ + "k1": "23", + "k2": "ab", + "k3": "true", + "k4": "1,2", + } + + assert.True(t, mp.Has("k1")) + assert.True(t, mp.HasValue("true")) + assert.True(t, mp.Bool("k3")) + assert.False(t, mp.IsEmpty()) + assert.False(t, mp.HasValue("not-exist")) + assert.Len(t, mp.Keys(), 4) + assert.Len(t, mp.Values(), 4) + + val, ok := mp.Value("k2") + assert.True(t, ok) + assert.Eq(t, "ab", val) + + // int + assert.Eq(t, 23, mp.Int("k1")) + assert.Eq(t, int64(23), mp.Int64("k1")) + + // str + assert.Eq(t, "23", mp.Str("k1")) + assert.Eq(t, "ab", mp.Get("k2")) + + // slice + assert.Eq(t, []int{1, 2}, mp.Ints("k4")) + assert.Eq(t, []string{"1", "2"}, mp.Strings("k4")) + assert.Nil(t, mp.Strings("not-exist")) + + // not exists + assert.False(t, mp.Bool("notExists")) + assert.Eq(t, 0, mp.Int("notExists")) + assert.Eq(t, int64(0), mp.Int64("notExists")) + assert.Eq(t, "", mp.Str("notExists")) + assert.Empty(t, mp.Ints("notExists")) +} + +func TestSMap_ToKVPairs(t *testing.T) { + mp := nmap.SMap{ + "k1": "23", + "k2": "ab", + } + arr := mp.ToKVPairs() + assert.Len(t, arr, 4) + str := fmt.Sprint(arr) + assert.StrContains(t, str, "k1 23") + assert.StrContains(t, str, "k2 ab") +} diff --git a/nmap/util.go b/nmap/util.go new file mode 100644 index 0000000..2e3e607 --- /dev/null +++ b/nmap/util.go @@ -0,0 +1,128 @@ +package nmap + +import ( + "git.noahlan.cn/noahlan/ntool/narr" + "reflect" + "strings" +) + +// Key, value sep char const +const ( + ValSepStr = "," + ValSepChar = ',' + KeySepStr = "." + KeySepChar = '.' +) + +// SimpleMerge simple merge two data map by string key. +// will merge the src to dst map +func SimpleMerge(src, dst map[string]any) map[string]any { + if len(src) == 0 { + return dst + } + + if len(dst) == 0 { + return src + } + + for key, val := range src { + dst[key] = val + } + return dst +} + +// func DeepMerge(src, dst map[string]any, deep int) map[string]any { TODO +// } + +// MergeSMap simple merge two string map. merge src to dst map +func MergeSMap(src, dst map[string]string, ignoreCase bool) map[string]string { + return MergeStringMap(src, dst, ignoreCase) +} + +// MergeStringMap simple merge two string map. merge src to dst map +func MergeStringMap(src, dst map[string]string, ignoreCase bool) map[string]string { + if len(src) == 0 { + return dst + } + if len(dst) == 0 { + return src + } + + for k, v := range src { + if ignoreCase { + k = strings.ToLower(k) + } + + dst[k] = v + } + return dst +} + +// MakeByPath build new value by key names +// +// Example: +// +// "site.info" +// -> +// map[string]any { +// site: {info: val} +// } +// +// // case 2, last key is slice: +// "site.tags[1]" +// -> +// map[string]any { +// site: {tags: [val]} +// } +func MakeByPath(path string, val any) (mp map[string]any) { + return MakeByKeys(strings.Split(path, KeySepStr), val) +} + +// MakeByKeys build new value by key names +// +// Example: +// +// // case 1: +// []string{"site", "info"} +// -> +// map[string]any { +// site: {info: val} +// } +// +// // case 2, last key is slice: +// []string{"site", "tags[1]"} +// -> +// map[string]any { +// site: {tags: [val]} +// } +func MakeByKeys(keys []string, val any) (mp map[string]any) { + size := len(keys) + + // if last key contains slice index, make slice wrap the val + lastKey := keys[size-1] + if newK, idx, ok := parseArrKeyIndex(lastKey); ok { + // valTyp := reflect.TypeOf(val) + sliTyp := reflect.SliceOf(reflect.TypeOf(val)) + sliVal := reflect.MakeSlice(sliTyp, idx+1, idx+1) + sliVal.Index(idx).Set(reflect.ValueOf(val)) + + // update val and last key + val = sliVal.Interface() + keys[size-1] = newK + } + + if size == 1 { + return map[string]any{keys[0]: val} + } + + // multi nodes + narr.Reverse(keys) + for _, p := range keys { + if mp == nil { + mp = map[string]any{p: val} + } else { + mp = map[string]any{p: mp} + } + } + return +} diff --git a/nmap/util_test.go b/nmap/util_test.go new file mode 100644 index 0000000..6f5cf6f --- /dev/null +++ b/nmap/util_test.go @@ -0,0 +1,74 @@ +package nmap_test + +import ( + "git.noahlan.cn/noahlan/ntool/nmap" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "reflect" + "testing" +) + +func TestSimpleMerge(t *testing.T) { + src := map[string]any{"A": "v0"} + dst := map[string]any{"A": "v1", "B": "v2"} + ret := nmap.SimpleMerge(src, dst) + assert.Len(t, ret, 2) + assert.Eq(t, "v0", ret["A"]) + + dst = map[string]any{"A": "v1", "B": "v2"} + ret = nmap.SimpleMerge(nil, dst) + assert.Eq(t, "v1", ret["A"]) + + ret = nmap.SimpleMerge(src, nil) + assert.Eq(t, "v0", ret["A"]) +} + +func TestMergeStringMap(t *testing.T) { + ret := nmap.MergeSMap(map[string]string{"A": "v0"}, map[string]string{"A": "v1"}, false) + assert.Eq(t, map[string]string{"A": "v0"}, ret) + + ret = nmap.MergeSMap(map[string]string{"A": "v0"}, map[string]string{"a": "v1"}, true) + assert.Eq(t, map[string]string{"a": "v0"}, ret) +} + +func TestMakeByPath(t *testing.T) { + mp := nmap.MakeByPath("top.sub", "val") + + assert.NotEmpty(t, mp) + assert.ContainsKey(t, mp, "top") + assert.IsKind(t, reflect.Map, mp["top"]) + assert.Eq(t, "val", nmap.DeepGet(mp, "top.sub")) + + mp = nmap.MakeByPath("top.arr[1]", "val") + //dump.P(mp) + assert.NotEmpty(t, mp) + assert.ContainsKey(t, mp, "top") + assert.Eq(t, "{top:map[arr:[ val]]}", nmap.ToString(mp)) + assert.Eq(t, []string{"", "val"}, nmap.DeepGet(mp, "top.arr")) + assert.Eq(t, "val", nmap.DeepGet(mp, "top.arr.1")) +} + +func TestMakeByKeys(t *testing.T) { + mp := nmap.MakeByKeys([]string{"top"}, "val") + assert.NotEmpty(t, mp) + assert.ContainsKey(t, mp, "top") + assert.Eq(t, "val", mp["top"]) + + mp = nmap.MakeByKeys([]string{"top", "sub"}, "val") + assert.NotEmpty(t, mp) + assert.ContainsKey(t, mp, "top") + assert.IsKind(t, reflect.Map, mp["top"]) + + mp = nmap.MakeByKeys([]string{"top_arr[]"}, 234) + // dump.P(mp) + assert.NotEmpty(t, mp) + assert.IsKind(t, reflect.Slice, mp["top_arr"]) + assert.Eq(t, 234, nmap.DeepGet(mp, "top_arr.0")) + + mp = nmap.MakeByKeys([]string{"top", "arr[1]"}, "val") + //dump.P(mp) + assert.NotEmpty(t, mp) + assert.ContainsKey(t, mp, "top") + assert.Eq(t, "{top:map[arr:[ val]]}", nmap.ToString(mp)) + assert.Eq(t, []string{"", "val"}, nmap.DeepGet(mp, "top.arr")) + assert.Eq(t, "val", nmap.DeepGet(mp, "top.arr.1")) +} diff --git a/nmath/check.go b/nmath/check.go new file mode 100644 index 0000000..eef5056 --- /dev/null +++ b/nmath/check.go @@ -0,0 +1,97 @@ +package nmath + +import "git.noahlan.cn/noahlan/ntool/ndef" + +// Compare any intX,floatX value by given op. returns `srcVal op(=,!=,<,<=,>,>=) dstVal` +// +// Usage: +// +// nmath.Compare(2, 3, ">") // false +// nmath.Compare(2, 1.3, ">") // true +// nmath.Compare(2.2, 1.3, ">") // true +// nmath.Compare(2.1, 2, ">") // true +func Compare(srcVal, dstVal any, op string) (ok bool) { + if srcVal == nil || dstVal == nil { + return false + } + + // float + if srcFlt, ok := srcVal.(float64); ok { + if dstFlt, err := ToFloat(dstVal); err == nil { + return CompFloat(srcFlt, dstFlt, op) + } + return false + } + + if srcFlt, ok := srcVal.(float32); ok { + if dstFlt, err := ToFloat(dstVal); err == nil { + return CompFloat(float64(srcFlt), dstFlt, op) + } + return false + } + + // as int64 + srcInt, err := ToInt64(srcVal) + if err != nil { + return false + } + + dstInt, err := ToInt64(dstVal) + if err != nil { + return false + } + + return CompInt64(srcInt, dstInt, op) +} + +// CompInt compare int,uint value. returns `srcVal op(=,!=,<,<=,>,>=) dstVal` +func CompInt[T ndef.XInt](srcVal, dstVal T, op string) (ok bool) { + return CompValue(srcVal, dstVal, op) +} + +// CompInt64 compare int64 value. returns `srcVal op(=,!=,<,<=,>,>=) dstVal` +func CompInt64(srcVal, dstVal int64, op string) bool { + return CompValue(srcVal, dstVal, op) +} + +// CompFloat compare float64,float32 value. returns `srcVal op(=,!=,<,<=,>,>=) dstVal` +func CompFloat[T ndef.Float](srcVal, dstVal T, op string) (ok bool) { + return CompValue(srcVal, dstVal, op) +} + +// CompValue compare intX,uintX,floatX value. returns `srcVal op(=,!=,<,<=,>,>=) dstVal` +func CompValue[T ndef.XIntOrFloat](srcVal, dstVal T, op string) (ok bool) { + switch op { + case "<", "lt": + ok = srcVal < dstVal + case "<=", "lte": + ok = srcVal <= dstVal + case ">", "gt": + ok = srcVal > dstVal + case ">=", "gte": + ok = srcVal >= dstVal + case "=", "eq": + ok = srcVal == dstVal + case "!=", "ne", "neq": + ok = srcVal != dstVal + } + return +} + +// InRange check if val in int/float range [min, max] +func InRange[T ndef.IntOrFloat](val, min, max T) bool { + return val >= min && val <= max +} + +// OutRange check if val not in int/float range [min, max] +func OutRange[T ndef.IntOrFloat](val, min, max T) bool { + return val < min || val > max +} + +// InUintRange check if val in unit range [min, max] +func InUintRange[T ndef.Uint](val, min, max T) bool { + if max == 0 { + return val >= min + } + return val >= min && val <= max +} diff --git a/nmath/check_test.go b/nmath/check_test.go new file mode 100644 index 0000000..5a5ebc8 --- /dev/null +++ b/nmath/check_test.go @@ -0,0 +1,67 @@ +package nmath_test + +import ( + "git.noahlan.cn/noahlan/ntool/ndef" + "git.noahlan.cn/noahlan/ntool/nmath" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestCompare(t *testing.T) { + tests := []struct { + x, y any + op string + }{ + {2, 2, ndef.OpEq}, + {2, 3, ndef.OpNeq}, + {2, 3, ndef.OpLt}, + {2, 3, ndef.OpLte}, + {2, 2, ndef.OpLte}, + {2, 1, ndef.OpGt}, + {2, 2, ndef.OpGte}, + {2, 1, ndef.OpGte}, + {2, "1", ndef.OpGte}, + {2.2, 2.2, ndef.OpEq}, + {2.2, 3.1, ndef.OpNeq}, + {2.3, 3.2, ndef.OpLt}, + {2.3, 3.3, ndef.OpLte}, + {2.3, 2.3, ndef.OpLte}, + {2.3, 1.3, ndef.OpGt}, + {2.3, 2.3, ndef.OpGte}, + {2.3, 1.3, ndef.OpGte}, + } + + for _, test := range tests { + assert.True(t, nmath.Compare(test.x, test.y, test.op)) + } + + assert.False(t, nmath.Compare(2, 3, ndef.OpGt)) + assert.False(t, nmath.Compare(nil, 3, ndef.OpGt)) + assert.False(t, nmath.Compare(2, nil, ndef.OpGt)) + assert.False(t, nmath.Compare("abc", 3, ndef.OpGt)) + assert.False(t, nmath.Compare(2, "def", ndef.OpGt)) + + assert.True(t, nmath.CompInt64(2, 3, ndef.OpLt)) +} + +func TestInRange(t *testing.T) { + assert.True(t, nmath.InRange(1, 1, 2)) + assert.True(t, nmath.InRange(1, 1, 1)) + assert.False(t, nmath.InRange(1, 2, 1)) + assert.False(t, nmath.InRange(1, 2, 2)) + + assert.True(t, nmath.InRange(1.1, 1.1, 2.2)) + assert.True(t, nmath.InRange(1.1, 1.1, 1.1)) + assert.False(t, nmath.InRange(1.1, 2.2, 1.1)) + + // test for OutRange() + assert.False(t, nmath.OutRange(1, 1, 2)) + assert.False(t, nmath.OutRange(1, 1, 1)) + assert.True(t, nmath.OutRange(1, 2, 10)) + + // test for InUintRange() + assert.True(t, nmath.InUintRange[uint](1, 1, 2)) + assert.True(t, nmath.InUintRange[uint](1, 1, 1)) + assert.True(t, nmath.InUintRange[uint](1, 1, 0)) + assert.False(t, nmath.InUintRange[uint](1, 2, 1)) +} diff --git a/nmath/convert.go b/nmath/convert.go new file mode 100644 index 0000000..ba38cdc --- /dev/null +++ b/nmath/convert.go @@ -0,0 +1,404 @@ +package nmath + +import ( + "encoding/json" + "fmt" + "git.noahlan.cn/noahlan/ntool/ndef" + "strconv" + "strings" + "time" +) + +/************************************************************* + * convert value to int + *************************************************************/ + +// Int convert value to int +func Int(in any) (int, error) { + return ToInt(in) +} + +// QuietInt convert value to int, will ignore error +func QuietInt(in any) int { + val, _ := ToInt(in) + return val +} + +// MustInt convert value to int, will panic on error +func MustInt(in any) int { + val, _ := ToInt(in) + return val +} + +// IntOrPanic convert value to int, will panic on error +func IntOrPanic(in any) int { + val, err := ToInt(in) + if err != nil { + panic(err) + } + return val +} + +// IntOrErr convert value to int, return error on failed +func IntOrErr(in any) (iVal int, err error) { + return ToInt(in) +} + +// ToInt convert value to int, return error on failed +func ToInt(in any) (iVal int, err error) { + switch tVal := in.(type) { + case nil: + iVal = 0 + case int: + iVal = tVal + case int8: + iVal = int(tVal) + case int16: + iVal = int(tVal) + case int32: + iVal = int(tVal) + case int64: + iVal = int(tVal) + case uint: + iVal = int(tVal) + case uint8: + iVal = int(tVal) + case uint16: + iVal = int(tVal) + case uint32: + iVal = int(tVal) + case uint64: + iVal = int(tVal) + case float32: + iVal = int(tVal) + case float64: + iVal = int(tVal) + case time.Duration: + iVal = int(tVal) + case string: + iVal, err = strconv.Atoi(strings.TrimSpace(tVal)) + case json.Number: + var i64 int64 + i64, err = tVal.Int64() + iVal = int(i64) + default: + err = ndef.ErrConvType + } + return +} + +// StrInt convert. +func StrInt(s string) int { + iVal, _ := strconv.Atoi(strings.TrimSpace(s)) + return iVal +} + +/************************************************************* + * convert value to uint + *************************************************************/ + +// Uint convert string to uint, return error on failed +func Uint(in any) (uint64, error) { + return ToUint(in) +} + +// QuietUint convert string to uint, will ignore error +func QuietUint(in any) uint64 { + val, _ := ToUint(in) + return val +} + +// MustUint convert string to uint, will panic on error +func MustUint(in any) uint64 { + val, _ := ToUint(in) + return val +} + +// UintOrErr convert value to uint, return error on failed +func UintOrErr(in any) (uint64, error) { + return ToUint(in) +} + +// ToUint convert value to uint, return error on failed +func ToUint(in any) (u64 uint64, err error) { + switch tVal := in.(type) { + case nil: + u64 = 0 + case int: + u64 = uint64(tVal) + case int8: + u64 = uint64(tVal) + case int16: + u64 = uint64(tVal) + case int32: + u64 = uint64(tVal) + case int64: + u64 = uint64(tVal) + case uint: + u64 = uint64(tVal) + case uint8: + u64 = uint64(tVal) + case uint16: + u64 = uint64(tVal) + case uint32: + u64 = uint64(tVal) + case uint64: + u64 = tVal + case float32: + u64 = uint64(tVal) + case float64: + u64 = uint64(tVal) + case time.Duration: + u64 = uint64(tVal) + case json.Number: + var i64 int64 + i64, err = tVal.Int64() + u64 = uint64(i64) + case string: + u64, err = strconv.ParseUint(strings.TrimSpace(tVal), 10, 0) + default: + err = ndef.ErrConvType + } + return +} + +/************************************************************* + * convert value to int64 + *************************************************************/ + +// Int64 convert string to int64, return error on failed +func Int64(in any) (int64, error) { + return ToInt64(in) +} + +// SafeInt64 convert value to int64, will ignore error +func SafeInt64(in any) int64 { + i64, _ := ToInt64(in) + return i64 +} + +// QuietInt64 convert value to int64, will ignore error +func QuietInt64(in any) int64 { + i64, _ := ToInt64(in) + return i64 +} + +// MustInt64 convert value to int64, will panic on error +func MustInt64(in any) int64 { + i64, _ := ToInt64(in) + return i64 +} + +// TODO StrictInt64,AsInt64 strict convert to int64 + +// Int64OrErr convert string to int64, return error on failed +func Int64OrErr(in any) (int64, error) { + return ToInt64(in) +} + +// ToInt64 convert string to int64, return error on failed +func ToInt64(in any) (i64 int64, err error) { + switch tVal := in.(type) { + case nil: + i64 = 0 + case string: + i64, err = strconv.ParseInt(strings.TrimSpace(tVal), 10, 0) + case int: + i64 = int64(tVal) + case int8: + i64 = int64(tVal) + case int16: + i64 = int64(tVal) + case int32: + i64 = int64(tVal) + case int64: + i64 = tVal + case uint: + i64 = int64(tVal) + case uint8: + i64 = int64(tVal) + case uint16: + i64 = int64(tVal) + case uint32: + i64 = int64(tVal) + case uint64: + i64 = int64(tVal) + case float32: + i64 = int64(tVal) + case float64: + i64 = int64(tVal) + case time.Duration: + i64 = int64(tVal) + case json.Number: + i64, err = tVal.Int64() + default: + err = ndef.ErrConvType + } + return +} + +/************************************************************* + * convert value to float + *************************************************************/ + +// QuietFloat convert value to float64, will ignore error +func QuietFloat(in any) float64 { + val, _ := ToFloat(in) + return val +} + +// FloatOrPanic convert value to float64, will panic on error +func FloatOrPanic(in any) float64 { + val, err := ToFloat(in) + if err != nil { + panic(err) + } + return val +} + +// MustFloat convert value to float64, will panic on error +func MustFloat(in any) float64 { + val, err := ToFloat(in) + if err != nil { + panic(err) + } + return val +} + +// Float convert value to float64, return error on failed +func Float(in any) (float64, error) { + return ToFloat(in) +} + +// FloatOrErr convert value to float64, return error on failed +func FloatOrErr(in any) (float64, error) { + return ToFloat(in) +} + +// ToFloat convert value to float64, return error on failed +func ToFloat(in any) (f64 float64, err error) { + switch tVal := in.(type) { + case nil: + f64 = 0 + case string: + f64, err = strconv.ParseFloat(strings.TrimSpace(tVal), 64) + case int: + f64 = float64(tVal) + case int8: + f64 = float64(tVal) + case int16: + f64 = float64(tVal) + case int32: + f64 = float64(tVal) + case int64: + f64 = float64(tVal) + case uint: + f64 = float64(tVal) + case uint8: + f64 = float64(tVal) + case uint16: + f64 = float64(tVal) + case uint32: + f64 = float64(tVal) + case uint64: + f64 = float64(tVal) + case float32: + f64 = float64(tVal) + case float64: + f64 = tVal + case time.Duration: + f64 = float64(tVal) + case json.Number: + f64, err = tVal.Float64() + default: + err = ndef.ErrConvType + } + return +} + +/************************************************************* + * convert intX/floatX to string + *************************************************************/ + +// StringOrPanic convert intX/floatX value to string, will panic on error +func StringOrPanic(val any) string { + str, err := TryToString(val, true) + if err != nil { + panic(err) + } + return str +} + +// MustString convert intX/floatX value to string, will panic on error +func MustString(val any) string { + return StringOrPanic(val) +} + +// ToString convert intX/floatX value to string, return error on failed +func ToString(val any) (string, error) { + return TryToString(val, true) +} + +// StringOrErr convert intX/floatX value to string, return error on failed +func StringOrErr(val any) (string, error) { + return TryToString(val, true) +} + +// QuietString convert intX/floatX value to string, other type convert by fmt.Sprint +func QuietString(val any) string { + str, _ := TryToString(val, false) + return str +} + +// String convert intX/floatX value to string, other type convert by fmt.Sprint +func String(val any) string { + str, _ := TryToString(val, false) + return str +} + +// TryToString try convert intX/floatX value to string +// +// if defaultAsErr is False, will use fmt.Sprint convert other type +func TryToString(val any, defaultAsErr bool) (str string, err error) { + if val == nil { + return + } + + switch value := val.(type) { + case int: + str = strconv.Itoa(value) + case int8: + str = strconv.Itoa(int(value)) + case int16: + str = strconv.Itoa(int(value)) + case int32: // same as `rune` + str = strconv.Itoa(int(value)) + case int64: + str = strconv.FormatInt(value, 10) + case uint: + str = strconv.FormatUint(uint64(value), 10) + case uint8: + str = strconv.FormatUint(uint64(value), 10) + case uint16: + str = strconv.FormatUint(uint64(value), 10) + case uint32: + str = strconv.FormatUint(uint64(value), 10) + case uint64: + str = strconv.FormatUint(value, 10) + case float32: + str = strconv.FormatFloat(float64(value), 'f', -1, 32) + case float64: + str = strconv.FormatFloat(value, 'f', -1, 64) + case time.Duration: + str = strconv.FormatInt(int64(value), 10) + case fmt.Stringer: + str = value.String() + default: + if defaultAsErr { + err = ndef.ErrConvType + } else { + str = fmt.Sprint(value) + } + } + return +} diff --git a/nmath/convert_test.go b/nmath/convert_test.go new file mode 100644 index 0000000..6ded17f --- /dev/null +++ b/nmath/convert_test.go @@ -0,0 +1,164 @@ +package nmath_test + +import ( + "encoding/json" + "git.noahlan.cn/noahlan/ntool/nmath" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" + "time" +) + +func TestToInt(t *testing.T) { + is := assert.New(t) + + tests := []any{ + 2, + int8(2), int16(2), int32(2), int64(2), + uint(2), uint8(2), uint16(2), uint32(2), uint64(2), + float32(2.2), 2.3, + "2", + time.Duration(2), + json.Number("2"), + } + errTests := []any{ + nil, + "2a", + []int{1}, + } + + // To int + intVal, err := nmath.Int("2") + is.Nil(err) + is.Eq(2, intVal) + + intVal, err = nmath.ToInt("-2") + is.Nil(err) + is.Eq(-2, intVal) + + is.Eq(2, nmath.StrInt("2")) + + intVal, err = nmath.IntOrErr("-2") + is.Nil(err) + is.Eq(-2, intVal) + + is.Eq(-2, nmath.MustInt("-2")) + for _, in := range tests { + is.Eq(2, nmath.MustInt(in)) + is.Eq(2, nmath.QuietInt(in)) + } + for _, in := range errTests { + is.Eq(0, nmath.MustInt(in)) + } + + // To uint + uintVal, err := nmath.Uint("2") + is.Nil(err) + is.Eq(uint64(2), uintVal) + + uintVal, err = nmath.UintOrErr("2") + is.Nil(err) + is.Eq(uint64(2), uintVal) + + _, err = nmath.ToUint("-2") + is.Err(err) + + is.Eq(uint64(0), nmath.MustUint("-2")) + for _, in := range tests { + is.Eq(uint64(2), nmath.MustUint(in)) + } + for _, in := range errTests { + is.Eq(uint64(0), nmath.QuietUint(in)) + is.Eq(uint64(0), nmath.MustUint(in)) + } + + // To int64 + i64Val, err := nmath.ToInt64("2") + is.Nil(err) + is.Eq(int64(2), i64Val) + + i64Val, err = nmath.Int64("-2") + is.Nil(err) + is.Eq(int64(-2), i64Val) + + i64Val, err = nmath.Int64OrErr("-2") + is.Nil(err) + is.Eq(int64(-2), i64Val) + + for _, in := range tests { + is.Eq(int64(2), nmath.MustInt64(in)) + } + for _, in := range errTests { + is.Eq(int64(0), nmath.MustInt64(in)) + is.Eq(int64(0), nmath.QuietInt64(in)) + is.Eq(int64(0), nmath.SafeInt64(in)) + } +} + +func TestToString(t *testing.T) { + is := assert.New(t) + + tests := []any{ + 2, + int8(2), int16(2), int32(2), int64(2), + uint(2), uint8(2), uint16(2), uint32(2), uint64(2), + float32(2), float64(2), + // "2", + time.Duration(2), + json.Number("2"), + } + + for _, in := range tests { + is.Eq("2", nmath.String(in)) + is.Eq("2", nmath.QuietString(in)) + is.Eq("2", nmath.MustString(in)) + val, err := nmath.ToString(in) + is.NoErr(err) + is.Eq("2", val) + } + + val, err := nmath.StringOrErr(2) + is.NoErr(err) + is.Eq("2", val) + + val, err = nmath.ToString(nil) + is.NoErr(err) + is.Eq("", val) + + is.Panics(func() { + nmath.MustString("2") + }) +} + +func TestToFloat(t *testing.T) { + is := assert.New(t) + + tests := []any{ + 2, + int8(2), int16(2), int32(2), int64(2), + uint(2), uint8(2), uint16(2), uint32(2), uint64(2), + float32(2), float64(2), + "2", + time.Duration(2), + json.Number("2"), + } + for _, in := range tests { + is.Eq(float64(2), nmath.MustFloat(in)) + } + + is.Eq(123.5, nmath.MustFloat("123.5")) + is.Eq(123.5, nmath.QuietFloat("123.5")) + is.Panics(func() { nmath.MustFloat("invalid") }, "ok") + is.Eq(float64(0), nmath.QuietFloat("invalid")) + + fltVal, err := nmath.ToFloat("123.5") + is.Nil(err) + is.Eq(123.5, fltVal) + + fltVal, err = nmath.Float("-123.5") + is.Nil(err) + is.Eq(-123.5, fltVal) + + fltVal, err = nmath.FloatOrErr("-123.5") + is.Nil(err) + is.Eq(-123.5, fltVal) +} diff --git a/nmath/number.go b/nmath/number.go new file mode 100644 index 0000000..ba05a38 --- /dev/null +++ b/nmath/number.go @@ -0,0 +1,14 @@ +package nmath + +// IsNumeric returns true if the given character is a numeric, otherwise false. +func IsNumeric(c byte) bool { + return c >= '0' && c <= '9' +} + +// Percent returns a values percent of the total +func Percent(val, total int) float64 { + if total == 0 { + return float64(0) + } + return (float64(val) / float64(total)) * 100 +} diff --git a/nmath/number_test.go b/nmath/number_test.go new file mode 100644 index 0000000..4d3cd43 --- /dev/null +++ b/nmath/number_test.go @@ -0,0 +1,36 @@ +package nmath_test + +import ( + "git.noahlan.cn/noahlan/ntool/ngo" + "git.noahlan.cn/noahlan/ntool/nmath" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "git.noahlan.cn/noahlan/ntool/ntime" + "testing" + "time" +) + +func TestIsNumeric(t *testing.T) { + assert.True(t, nmath.IsNumeric('3')) + assert.False(t, nmath.IsNumeric('a')) +} + +func TestPercent(t *testing.T) { + assert.Eq(t, float64(34), nmath.Percent(34, 100)) + assert.Eq(t, float64(0), nmath.Percent(34, 0)) + assert.Eq(t, float64(-100), nmath.Percent(34, -34)) +} + +func TestElapsedTime(t *testing.T) { + nt := time.Now().Add(-time.Second * 3) + num := ntime.ElapsedTime(nt) + + assert.Eq(t, 3000, int(nmath.MustFloat(num))) +} + +func TestDataSize(t *testing.T) { + assert.Eq(t, "3.38K", ngo.DataSize(3456)) +} + +func TestHowLongAgo(t *testing.T) { + assert.Eq(t, "57 mins", ntime.HowLongAgo(3456)) +} diff --git a/nmath/util.go b/nmath/util.go new file mode 100644 index 0000000..8d75ce0 --- /dev/null +++ b/nmath/util.go @@ -0,0 +1,83 @@ +package nmath + +import ( + "git.noahlan.cn/noahlan/ntool/ndef" + "math" +) + +// Min compare two value and return max value +func Min[T ndef.XIntOrFloat](x, y T) T { + if x < y { + return x + } + return y +} + +// Max compare two value and return max value +func Max[T ndef.XIntOrFloat](x, y T) T { + if x > y { + return x + } + return y +} + +// SwapMin compare and always return [min, max] value +func SwapMin[T ndef.XIntOrFloat](x, y T) (T, T) { + if x < y { + return x, y + } + return y, x +} + +// SwapMax compare and always return [max, min] value +func SwapMax[T ndef.XIntOrFloat](x, y T) (T, T) { + if x > y { + return x, y + } + return y, x +} + +// MaxInt compare and return max value +func MaxInt(x, y int) int { + if x > y { + return x + } + return y +} + +// SwapMaxInt compare and return max, min value +func SwapMaxInt(x, y int) (int, int) { + if x > y { + return x, y + } + return y, x +} + +// MaxI64 compare and return max value +func MaxI64(x, y int64) int64 { + if x > y { + return x + } + return y +} + +// SwapMaxI64 compare and return max, min value +func SwapMaxI64(x, y int64) (int64, int64) { + if x > y { + return x, y + } + return y, x +} + +// MaxFloat compare and return max value +func MaxFloat(x, y float64) float64 { + return math.Max(x, y) +} + +// OrElse return s OR nv(new-value) on s is empty +func OrElse[T ndef.XIntOrFloat](in, nv T) T { + if in != 0 { + return in + } + return nv +} diff --git a/nmath/util_test.go b/nmath/util_test.go new file mode 100644 index 0000000..f30e422 --- /dev/null +++ b/nmath/util_test.go @@ -0,0 +1,61 @@ +package nmath_test + +import ( + "git.noahlan.cn/noahlan/ntool/nmath" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestMaxFloat(t *testing.T) { + assert.Eq(t, float64(3), nmath.MaxFloat(2, 3)) + assert.Eq(t, 3.3, nmath.MaxFloat(2.1, 3.3)) + + assert.Eq(t, 3.3, nmath.Max(2.1, 3.3)) + assert.Eq(t, 3.3, nmath.Max(3.3, 2.1)) + + assert.Eq(t, 2.1, nmath.Min(2.1, 3.3)) + assert.Eq(t, 2.1, nmath.Min(3.3, 2.1)) +} + +func TestMaxI64(t *testing.T) { + assert.Eq(t, 3, nmath.MaxInt(2, 3)) + assert.Eq(t, 3, nmath.MaxInt(3, 2)) + + assert.Eq(t, int64(3), nmath.MaxI64(2, 3)) + assert.Eq(t, int64(3), nmath.MaxI64(3, 2)) + + assert.Eq(t, 3, nmath.Max[int](3, 2)) + assert.Eq(t, int64(3), nmath.Max[int64](3, 2)) + assert.Eq(t, int64(3), nmath.Max(int64(3), int64(2))) +} + +func TestSwapMaxInt(t *testing.T) { + x, y := nmath.SwapMax(2, 34) + assert.Eq(t, 34, x) + assert.Eq(t, 2, y) + + x, y = nmath.SwapMax(34, 2) + assert.Eq(t, 34, x) + assert.Eq(t, 2, y) + + x, y = nmath.SwapMin(2, 34) + assert.Eq(t, 2, x) + assert.Eq(t, 34, y) + + x, y = nmath.SwapMin(34, 2) + assert.Eq(t, 2, x) + assert.Eq(t, 34, y) + + x, y = nmath.SwapMaxInt(2, 34) + assert.Eq(t, 34, x) + assert.Eq(t, 2, y) + + x64, y64 := nmath.SwapMaxI64(2, 34) + assert.Eq(t, int64(34), x64) + assert.Eq(t, int64(2), y64) +} + +func TestOrElse(t *testing.T) { + assert.Eq(t, 23, nmath.OrElse(23, 21)) + assert.Eq(t, 21.3, nmath.OrElse[float64](0, 21.3)) +} diff --git a/nnet/util.go b/nnet/util.go new file mode 100644 index 0000000..4542320 --- /dev/null +++ b/nnet/util.go @@ -0,0 +1,60 @@ +package nnet + +import ( + "net" + "net/netip" +) + +// InternalIPOld get internal IP buy old logic +func InternalIPOld() (ip string) { + addrs, err := net.InterfaceAddrs() + if err != nil { + panic("Oops: " + err.Error()) + } + + for _, a := range addrs { + if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { + if ipNet.IP.To4() != nil { + // os.Stdout.WriteString(ipNet.IP.String() + "\n") + ip = ipNet.IP.String() + return + } + } + } + return +} + +// InternalIP get internal IP +func InternalIP() (ip string) { + addr := netip.IPv4Unspecified() + if addr.IsValid() { + return addr.String() + } + + addr = netip.IPv6Unspecified() + if addr.IsValid() { + return addr.String() + } + + return "" +} + +// InternalIPv4 get internal IPv4 +func InternalIPv4() (ip string) { + addr := netip.IPv4Unspecified() + + if addr.IsValid() { + return addr.String() + } + return "" +} + +// InternalIPv6 get internal IPv6 +func InternalIPv6() (ip string) { + addr := netip.IPv6Unspecified() + + if addr.IsValid() { + return addr.String() + } + return "" +} diff --git a/nrandom/bytes.go b/nrandom/bytes.go new file mode 100644 index 0000000..6c08c2b --- /dev/null +++ b/nrandom/bytes.go @@ -0,0 +1,20 @@ +package nrandom + +import ( + crand "crypto/rand" + "io" +) + +// RandBytes generate random byte slice. +func RandBytes(length int) []byte { + if length < 1 { + return []byte{} + } + b := make([]byte, length) + + if _, err := io.ReadFull(crand.Reader, b); err != nil { + return nil + } + + return b +} diff --git a/nrandom/const.go b/nrandom/const.go new file mode 100644 index 0000000..9d286b2 --- /dev/null +++ b/nrandom/const.go @@ -0,0 +1,14 @@ +package nrandom + +const ( + letterIdBits = 6 + letterIdMask = 1< 0 { + bs = append(bs, prefix...) + } + + // micro datatime + bs = nt.AppendFormat(bs, "20060102150405.000000") + bs[14+pl] = '0' + + // host + name, err := os.Hostname() + if err != nil { + name = "default" + } + c32 := crc32.ChecksumIEEE([]byte(name)) // eg: 4006367001 + bs = strconv.AppendUint(bs, uint64(c32%99), 10) + + // rand 10000 - 99999 + rs := rand.New(rand.NewSource(nt.UnixNano())) + bs = strconv.AppendInt(bs, 10000+rs.Int63n(89999), 10) + + return string(bs) +} diff --git a/nrandom/int.go b/nrandom/int.go new file mode 100644 index 0000000..dd3908c --- /dev/null +++ b/nrandom/int.go @@ -0,0 +1,34 @@ +package nrandom + +import ( + "math/rand" + "time" +) + +// RandInt generate random int between min and max, maybe min, not be max. +// +// Usage: +// +// RandInt(10, 99) +// RandInt(100, 999) +// RandInt(1000, 9999) +func RandInt(min, max int) int { + return RandIntWithSeed(min, max, time.Now().UnixNano()) +} + +// RandIntWithSeed return a random int at the [min, max) +// Usage: +// +// seed := time.Now().UnixNano() +// RandomIntWithSeed(1000, 9999, seed) +func RandIntWithSeed(min, max int, seed int64) int { + if min == max { + return min + } + if max < min { + min, max = max, min + } + r := rand.New(rand.NewSource(seed)) + + return r.Intn(max-min) + min +} diff --git a/nrandom/int_test.go b/nrandom/int_test.go new file mode 100644 index 0000000..6a7730a --- /dev/null +++ b/nrandom/int_test.go @@ -0,0 +1,24 @@ +package nrandom_test + +import ( + "fmt" + "git.noahlan.cn/noahlan/ntool/nrandom" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" + "time" +) + +func TestRandInt(t *testing.T) { + min, max := 1000, 9999 + for i := 0; i < 5; i++ { + val := nrandom.RandInt(min, max) + fmt.Println(val) + assert.True(t, val >= min) + assert.True(t, val <= max) + + seed := time.Now().UnixNano() + val = nrandom.RandIntWithSeed(min, max, seed) + assert.True(t, val >= min) + assert.True(t, val <= max) + } +} diff --git a/nrandom/snowflake/generator.go b/nrandom/snowflake/generator.go new file mode 100644 index 0000000..cfd545c --- /dev/null +++ b/nrandom/snowflake/generator.go @@ -0,0 +1,51 @@ +package snowflake + +import ( + "sync" + "time" +) + +type Generator interface { + // NextID 生成一个唯一ID + NextID() int64 +} + +const sleepDuration = 200 * time.Millisecond + +var ( + mutex sync.Mutex + snowflakeInstance Generator +) + +func GetDefaultInstance() Generator { + if snowflakeInstance == nil { + mutex.Lock() + snowflakeInstance = NewSnowWorkerOffset(nil) + time.Sleep(sleepDuration) + mutex.Unlock() + } + return snowflakeInstance +} + +func SetSnowflake(options *Options) { + snowflakeInstance = NewSnowflake(options) +} + +func NewSnowflake(options *Options) Generator { + mutex.Lock() + defer mutex.Unlock() + + var snowflakeInstance Generator + + switch options.Method { + case MethodSnowflakeOffset: + snowflakeInstance = NewSnowWorkerOffset(options) + time.Sleep(sleepDuration) + case MethodSnowflake: + fallthrough + default: + snowflakeInstance = NewSnowWorker(options) + } + + return snowflakeInstance +} diff --git a/nrandom/snowflake/options.go b/nrandom/snowflake/options.go new file mode 100644 index 0000000..75804f2 --- /dev/null +++ b/nrandom/snowflake/options.go @@ -0,0 +1,45 @@ +package snowflake + +import "time" + +type Method byte + +const ( + MethodSnowflakeOffset Method = iota + 1 + MethodSnowflake +) + +type Options struct { + Method Method // 雪花计算方法,(1-漂移算法|2-传统算法),默认1 + BaseTime int64 // 基础时间(ms单位),不能超过当前系统时间 + WorkerId uint16 // 机器码,必须由外部设定,最大值 2^WorkerIdBitLength-1 + WorkerIdBitLength byte // 机器码位长,默认值6,取值范围 [1, 15](要求:序列数位长+机器码位长不超过22) + SeqBitLength byte // 序列数位长,默认值6,取值范围 [3, 21](要求:序列数位长+机器码位长不超过22) + MaxSeqNumber uint32 // 最大序列数(含),设置范围 [MinSeqNumber, 2^SeqBitLength-1],默认值0,表示最大序列数取最大值(2^SeqBitLength-1]) + MinSeqNumber uint32 // 最小序列数(含),默认值5,取值范围 [5, MaxSeqNumber],每毫秒的前5个序列数对应编号0-4是保留位,其中1-4是时间回拨相应预留位,0是手工新值预留位 + TopOverCostCount uint32 // 最大漂移次数(含),默认2000,推荐范围500-10000(与计算能力有关) +} + +var defaultOptions = &Options{ + Method: MethodSnowflakeOffset, + BaseTime: time.Date(2022, time.February, 1, 0, 0, 0, 0, time.Local).UnixNano() / 1e6, // 2022-01-01 00:00:00 local + WorkerId: 1, + WorkerIdBitLength: 6, + SeqBitLength: 16, // 50w并发建议10 500w建议14-16 + MaxSeqNumber: 0, + MinSeqNumber: 5, + TopOverCostCount: 5000, // 500w并发建议5000-8000 +} + +func NewOptions(workerId uint16) *Options { + return &Options{ + WorkerId: workerId, + Method: defaultOptions.Method, + BaseTime: defaultOptions.BaseTime, + WorkerIdBitLength: defaultOptions.WorkerIdBitLength, + SeqBitLength: defaultOptions.SeqBitLength, + MaxSeqNumber: defaultOptions.MaxSeqNumber, + MinSeqNumber: defaultOptions.MinSeqNumber, + TopOverCostCount: defaultOptions.TopOverCostCount, + } +} diff --git a/nrandom/snowflake/snowflake.go b/nrandom/snowflake/snowflake.go new file mode 100644 index 0000000..e39aec4 --- /dev/null +++ b/nrandom/snowflake/snowflake.go @@ -0,0 +1,37 @@ +package snowflake + +import ( + "fmt" +) + +type SnowWorker struct { + *SnowWorkerOffset +} + +func NewSnowWorker(options *Options) *SnowWorker { + options.Method = 2 + return &SnowWorker{ + NewSnowWorkerOffset(options), + } +} + +func (s *SnowWorker) NextID() int64 { + s.Lock() + defer s.Unlock() + currentTimeTick := s.currentTimeTick() + if s.lastTimeTick == currentTimeTick { + s.currentSeqNumber++ + if s.currentSeqNumber > s.maxSeqNumber { + s.currentSeqNumber = s.minSeqNumber + currentTimeTick = s.nextTimeTick() + } + } else { + s.currentSeqNumber = s.minSeqNumber + } + if currentTimeTick < s.lastTimeTick { + fmt.Printf("Time error for %d milliseconds", s.lastTimeTick-currentTimeTick) + } + s.lastTimeTick = currentTimeTick + id := currentTimeTick< 22 { + workerIdBitLength = defaultOptions.WorkerIdBitLength + if workerIdBitLength+seqBitLength > 22 { + seqBitLength = defaultOptions.SeqBitLength + } + } + + minSeqNumber := options.MinSeqNumber + if minSeqNumber < 5 { + minSeqNumber = defaultOptions.MinSeqNumber + } + + maxSeqNumber := options.MaxSeqNumber + if maxSeqNumber <= 0 { + maxSeqNumber = (1 << seqBitLength) - 1 + } + + topOverCostCount := options.TopOverCostCount + if topOverCostCount == 0 { + topOverCostCount = defaultOptions.TopOverCostCount + } + + return &SnowWorkerOffset{ + baseTime: baseTime, + workerId: options.WorkerId, + workerIdBitLength: workerIdBitLength, + seqBitLength: seqBitLength, + maxSeqNumber: maxSeqNumber, + minSeqNumber: minSeqNumber, + topOverCostCount: topOverCostCount, + timestampShift: workerIdBitLength + seqBitLength, + currentSeqNumber: minSeqNumber, + + lastTimeTick: 0, + turnBackTimeTick: 0, + turnBackIndex: 0, + isOverCost: false, + overCostCountInOneTerm: 0, + } +} + +func (s *SnowWorkerOffset) NextID() int64 { + s.Lock() + defer s.Unlock() + if s.isOverCost { + return s.nextOverCost() + } else { + return s.nextNormal() + } +} + +func (s *SnowWorkerOffset) nextOverCost() int64 { + currentTimeTick := s.currentTimeTick() + if currentTimeTick > s.lastTimeTick { + s.lastTimeTick = currentTimeTick + s.currentSeqNumber = s.minSeqNumber + s.isOverCost = false + s.overCostCountInOneTerm = 0 + return s.calcId(s.lastTimeTick) + } + if s.overCostCountInOneTerm >= s.topOverCostCount { + s.lastTimeTick = s.nextTimeTick() + s.currentSeqNumber = s.minSeqNumber + s.isOverCost = false + s.overCostCountInOneTerm = 0 + return s.calcId(s.lastTimeTick) + } + if s.currentSeqNumber > s.maxSeqNumber { + s.lastTimeTick++ + s.currentSeqNumber = s.minSeqNumber + s.isOverCost = true + s.overCostCountInOneTerm++ + return s.calcId(s.lastTimeTick) + } + return s.calcId(s.lastTimeTick) +} + +func (s *SnowWorkerOffset) nextNormal() int64 { + currentTimeTick := s.currentTimeTick() + if currentTimeTick < s.lastTimeTick { + if s.turnBackTimeTick < 1 { + s.turnBackTimeTick = s.lastTimeTick - 1 + s.turnBackIndex++ + // 每毫秒序列数的前5位是预留位,0用于手工新值,1-4是时间回拨次序 + // 支持4次回拨次序(避免回拨重叠导致ID重复),可无限次回拨(次序循环使用)。 + if s.turnBackIndex > 4 { + s.turnBackIndex = 1 + } + } + return s.calcTurnBackId(s.turnBackTimeTick) + } + // 时间追平时,turnBackTimeTick 清零 + if s.turnBackTimeTick > 0 { + s.turnBackTimeTick = 0 + } + if currentTimeTick > s.lastTimeTick { + s.lastTimeTick = currentTimeTick + s.currentSeqNumber = s.minSeqNumber + return s.calcId(s.lastTimeTick) + } + if s.currentSeqNumber > s.maxSeqNumber { + s.lastTimeTick++ + s.currentSeqNumber = s.minSeqNumber + s.isOverCost = true + s.overCostCountInOneTerm = 1 + return s.calcId(s.lastTimeTick) + } + return s.calcId(s.lastTimeTick) +} + +func (s *SnowWorkerOffset) calcId(timeTick int64) int64 { + id := timeTick<= 0; { + if remain == 0 { + cache, remain = r.Int63(), letterIdMax + } + if idx := int(cache & letterIdMask); idx < len(letters) { + b[i] = letters[idx] + i-- + } + cache >>= letterIdBits + remain-- + } + return *(*string)(unsafe.Pointer(&b)) +} + +// RandString generate random string of specified length. +func RandString(length int) string { + return Random(Letters, length) +} + +// RandStrUpper generate a random upper case string. +func RandStrUpper(length int) string { + return Random(UpperLetters, length) +} + +// RandStrLower generate a random lower case string. +func RandStrLower(length int) string { + return Random(LowerLetters, length) +} + +// RandNumeralStr generate a random numeral string of specified length. +func RandNumeralStr(length int) string { + return Random(Numeral, length) +} + +// RandNumeralOrLetter generate a random numeral or letter string. +func RandNumeralOrLetter(length int) string { + return Random(Numeral+Letters, length) +} diff --git a/nrandom/uuid.go b/nrandom/uuid.go new file mode 100644 index 0000000..a9ec090 --- /dev/null +++ b/nrandom/uuid.go @@ -0,0 +1,73 @@ +package nrandom + +import ( + "git.noahlan.cn/noahlan/ntool/nrandom/snowflake" + "github.com/gofrs/uuid/v5" + "strconv" +) + +// SnowflakeId returns a snowflake uuid by default instance. +func SnowflakeId() int64 { + return snowflake.GetDefaultInstance().NextID() +} + +// SnowflakeIdStr returns a snowflake uuid by default instance. +// convert to string +func SnowflakeIdStr() string { + return strconv.FormatInt(SnowflakeId(), 10) +} + +// UUIDV4 returns a canonical RFC-4122 string representation of the UUID: +// xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx +func UUIDV4() string { + return NewUUIDV4().String() +} + +// NewUUIDV7 returns a new UUID +func NewUUIDV7() uuid.UUID { + id, err := uuid.NewV7() + if err != nil { + return uuid.UUID{} + } + return id +} + +// NewUUIDV6 returns a new UUID +func NewUUIDV6() uuid.UUID { + id, err := uuid.NewV6() + if err != nil { + return uuid.UUID{} + } + return id +} + +// NewUUIDV4 returns a new UUID +func NewUUIDV4() uuid.UUID { + id, err := uuid.NewV4() + if err != nil { + return uuid.UUID{} + } + return id +} + +// ParseUUIDSlice parses the UUID string slice to UUID slice +func ParseUUIDSlice(ids []string) []uuid.UUID { + var result []uuid.UUID + for _, v := range ids { + p, err := uuid.FromString(v) + if err != nil { + return nil + } + result = append(result, p) + } + return result +} + +// ParseUUIDString parses UUID string to UUID type +func ParseUUIDString(id string) uuid.UUID { + result, err := uuid.FromString(id) + if err != nil { + return uuid.UUID{} + } + return result +} diff --git a/nrandom/uuid_test.go b/nrandom/uuid_test.go new file mode 100644 index 0000000..7ec82e3 --- /dev/null +++ b/nrandom/uuid_test.go @@ -0,0 +1,62 @@ +package nrandom_test + +import ( + "fmt" + "git.noahlan.cn/noahlan/ntool/narr" + "git.noahlan.cn/noahlan/ntool/ndef" + "git.noahlan.cn/noahlan/ntool/nrandom" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func baseTestRandom[T ~string | ndef.XIntOrFloat](t *testing.T, count int, genFn func() T, singleCheck func(item T)) { + acts := make([]T, count) + for i := 0; i < count; i++ { + act := genFn() + if i < 2 { + fmt.Println(act) + } + acts[i] = act + singleCheck(act) + } + uniqueLen := len(narr.Unique(acts)) + assert.True(t, len(acts) == uniqueLen) +} + +const TestCount = 10000 + +func TestSnowflakeId(t *testing.T) { + baseTestRandom(t, TestCount, nrandom.SnowflakeId, func(item int64) { + assert.True(t, item != 0) + }) +} + +func TestUUIDV4(t *testing.T) { + baseTestRandom(t, TestCount, nrandom.UUIDV4, func(item string) { + assert.NotEmpty(t, item) + }) +} + +func TestNewUUIDV4(t *testing.T) { + baseTestRandom(t, TestCount, func() string { + return nrandom.NewUUIDV4().String() + }, func(item string) { + assert.NotEmpty(t, item) + }) +} + +func TestNewUUIDV6(t *testing.T) { + baseTestRandom(t, TestCount, func() string { + return nrandom.NewUUIDV6().String() + }, func(item string) { + assert.NotEmpty(t, item) + }) +} + +func TestNewUUIDV7(t *testing.T) { + baseTestRandom(t, TestCount, func() string { + return nrandom.NewUUIDV7().String() + }, func(item string) { + assert.NotEmpty(t, item) + }) +} diff --git a/nreflect/check.go b/nreflect/check.go new file mode 100644 index 0000000..bfafd09 --- /dev/null +++ b/nreflect/check.go @@ -0,0 +1,158 @@ +package nreflect + +import ( + "bytes" + "reflect" +) + +// HasChild type check. eg: array, slice, map, struct +func HasChild(v reflect.Value) bool { + switch v.Kind() { + case reflect.Array, reflect.Slice, reflect.Map, reflect.Struct: + return true + } + return false +} + +// IsArrayOrSlice check. eg: array, slice +func IsArrayOrSlice(k reflect.Kind) bool { + return k == reflect.Slice || k == reflect.Array +} + +// IsSimpleKind kind in: string, bool, intX, uintX, floatX +func IsSimpleKind(k reflect.Kind) bool { + if reflect.String == k { + return true + } + return k > reflect.Invalid && k <= reflect.Float64 +} + +// IsAnyInt check is intX or uintX type +func IsAnyInt(k reflect.Kind) bool { + return k >= reflect.Int && k <= reflect.Uintptr +} + +// IsIntx check is intX or uintX type +func IsIntx(k reflect.Kind) bool { + return k >= reflect.Int && k <= reflect.Int64 +} + +// IsUintX check is intX or uintX type +func IsUintX(k reflect.Kind) bool { + return k >= reflect.Uint && k <= reflect.Uintptr +} + +// IsNil reflect value +func IsNil(v reflect.Value) bool { + switch v.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + return v.IsNil() + default: + return false + } +} + +// IsFunc value +func IsFunc(val any) bool { + if val == nil { + return false + } + return reflect.TypeOf(val).Kind() == reflect.Func +} + +// IsEqual determines if two objects are considered equal. +// +// TIP: cannot compare function type +func IsEqual(src, dst any) bool { + if src == nil || dst == nil { + return src == dst + } + + bs1, ok := src.([]byte) + if !ok { + return reflect.DeepEqual(src, dst) + } + + bs2, ok := dst.([]byte) + if !ok { + return false + } + + if bs1 == nil || bs2 == nil { + return bs1 == nil && bs2 == nil + } + return bytes.Equal(bs1, bs2) +} + +// IsEqualValues determines if two objects or two object-values are considered equal. +// +// TIP: cannot compare function type +func IsEqualValues(src, dst any) bool { + if IsEqual(src, dst) { + return true + } + dstType := reflect.TypeOf(dst) + if dstType == nil { + return false + } + srcValue := reflect.ValueOf(src) + if srcValue.IsValid() && srcValue.Type().ConvertibleTo(dstType) { + // Attempt comparison after type conversion + return reflect.DeepEqual(srcValue.Convert(dstType).Interface(), dst) + } + + return false +} + +// IsEmpty reflect value check +func IsEmpty(v reflect.Value) bool { + switch v.Kind() { + case reflect.Invalid: + return true + case reflect.String, reflect.Array: + return v.Len() == 0 + case reflect.Map, reflect.Slice: + return v.Len() == 0 || v.IsNil() + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Interface, reflect.Ptr, reflect.Func: + return v.IsNil() + } + + return reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) +} + +// IsEmptyValue reflect value check. +// Difference the IsEmpty(), if value is ptr, will check real elem. +// +// From src/pkg/encoding/json/encode.go. +func IsEmptyValue(v reflect.Value) bool { + switch v.Kind() { + case reflect.Array, reflect.Map, reflect.Slice, reflect.String: + return v.Len() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Interface, reflect.Ptr: + if v.IsNil() { + return true + } + return IsEmptyValue(v.Elem()) + case reflect.Func: + return v.IsNil() + case reflect.Invalid: + return true + } + return false +} diff --git a/nreflect/convert.go b/nreflect/convert.go new file mode 100644 index 0000000..f883333 --- /dev/null +++ b/nreflect/convert.go @@ -0,0 +1,195 @@ +package nreflect + +import ( + "fmt" + "git.noahlan.cn/noahlan/ntool/internal/convert" + "git.noahlan.cn/noahlan/ntool/ndef" + "git.noahlan.cn/noahlan/ntool/nmath" + "git.noahlan.cn/noahlan/ntool/nstr" + "reflect" + "strconv" +) + +// BaseTypeVal convert custom type or intX,uintX,floatX to generic base type. +// +// intX/unitX => int64 +// floatX => float64 +// string => string +// +// returns int64,string,float or error +func BaseTypeVal(v reflect.Value) (value any, err error) { + v = reflect.Indirect(v) + + switch v.Kind() { + case reflect.String: + value = v.String() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + value = v.Int() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + value = int64(v.Uint()) // always return int64 + case reflect.Float32, reflect.Float64: + value = v.Float() + default: + err = ndef.ErrConvType + } + return +} + +// ValueByType create reflect.Value by give reflect.Type +func ValueByType(val any, typ reflect.Type) (rv reflect.Value, err error) { + // handle kind: string, bool, intX, uintX, floatX + if typ.Kind() == reflect.String || typ.Kind() <= reflect.Float64 { + return ValueByKind(val, typ.Kind()) + } + + newRv := reflect.ValueOf(val) + + // try auto convert slice type + if IsArrayOrSlice(newRv.Kind()) && IsArrayOrSlice(typ.Kind()) { + return ConvSlice(newRv, typ.Elem()) + } + + // check type. like map + if newRv.Type() == typ { + return newRv, nil + } + + err = ndef.ErrConvType + return +} + +// ValueByKind create reflect.Value by give reflect.Kind +// +// TIPs: +// +// Only support kind: string, bool, intX, uintX, floatX +func ValueByKind(val any, kind reflect.Kind) (rv reflect.Value, err error) { + switch kind { + case reflect.Int: + if dstV, err1 := nmath.ToInt(val); err1 == nil { + rv = reflect.ValueOf(dstV) + } + case reflect.Int8: + if dstV, err1 := nmath.ToInt(val); err1 == nil { + rv = reflect.ValueOf(int8(dstV)) + } + case reflect.Int16: + if dstV, err1 := nmath.ToInt(val); err1 == nil { + rv = reflect.ValueOf(int16(dstV)) + } + case reflect.Int32: + if dstV, err1 := nmath.ToInt(val); err1 == nil { + rv = reflect.ValueOf(int32(dstV)) + } + case reflect.Int64: + if dstV, err1 := nmath.ToInt64(val); err1 == nil { + rv = reflect.ValueOf(dstV) + } + case reflect.Uint: + if dstV, err1 := nmath.ToUint(val); err1 == nil { + rv = reflect.ValueOf(uint(dstV)) + } + case reflect.Uint8: + if dstV, err1 := nmath.ToUint(val); err1 == nil { + rv = reflect.ValueOf(uint8(dstV)) + } + case reflect.Uint16: + if dstV, err1 := nmath.ToUint(val); err1 == nil { + rv = reflect.ValueOf(uint16(dstV)) + } + case reflect.Uint32: + if dstV, err1 := nmath.ToUint(val); err1 == nil { + rv = reflect.ValueOf(uint32(dstV)) + } + case reflect.Uint64: + if dstV, err1 := nmath.ToUint(val); err1 == nil { + rv = reflect.ValueOf(dstV) + } + case reflect.Float32: + if dstV, err1 := nmath.ToFloat(val); err1 == nil { + rv = reflect.ValueOf(float32(dstV)) + } + case reflect.Float64: + if dstV, err1 := nmath.ToFloat(val); err1 == nil { + rv = reflect.ValueOf(dstV) + } + case reflect.String: + if dstV, err1 := nstr.ToString(val); err1 == nil { + rv = reflect.ValueOf(dstV) + } + case reflect.Bool: + if bl, err := convert.ToBool(val); err == nil { + rv = reflect.ValueOf(bl) + } + } + + if !rv.IsValid() { + err = ndef.ErrConvType + } + return +} + +// ConvSlice make new type slice from old slice, will auto convert element type. +// +// TIPs: +// +// Only support kind: string, bool, intX, uintX, floatX +func ConvSlice(oldSlRv reflect.Value, newElemTyp reflect.Type) (rv reflect.Value, err error) { + if !IsArrayOrSlice(oldSlRv.Kind()) { + panic("only allow array or slice type value") + } + + // do not need convert type + if oldSlRv.Type().Elem() == newElemTyp { + return oldSlRv, nil + } + + newSlTyp := reflect.SliceOf(newElemTyp) + newSlRv := reflect.MakeSlice(newSlTyp, 0, 0) + for i := 0; i < oldSlRv.Len(); i++ { + newElemV, err := ValueByKind(oldSlRv.Index(i).Interface(), newElemTyp.Kind()) + if err != nil { + return reflect.Value{}, err + } + + newSlRv = reflect.Append(newSlRv, newElemV) + } + return newSlRv, nil +} + +// String convert +func String(rv reflect.Value) string { + s, _ := ValToString(rv, false) + return s +} + +// ToString convert +func ToString(rv reflect.Value) (str string, err error) { + return ValToString(rv, true) +} + +// ValToString convert handle +func ValToString(rv reflect.Value, defaultAsErr bool) (str string, err error) { + rv = Indirect(rv) + switch rv.Kind() { + case reflect.Invalid: + str = "" + case reflect.Bool: + str = strconv.FormatBool(rv.Bool()) + case reflect.String: + str = rv.String() + case reflect.Float32, reflect.Float64: + str = strconv.FormatFloat(rv.Float(), 'f', -1, 64) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + str = strconv.FormatInt(rv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + str = strconv.FormatUint(rv.Uint(), 10) + default: + if defaultAsErr { + err = ndef.ErrConvType + } else { + str = fmt.Sprint(rv.Interface()) + } + } + return +} diff --git a/nreflect/util.go b/nreflect/util.go new file mode 100644 index 0000000..f26dd0d --- /dev/null +++ b/nreflect/util.go @@ -0,0 +1,207 @@ +package nreflect + +import ( + "fmt" + "reflect" + "strconv" + "unsafe" +) + +// Elem returns the value that the interface v contains +// or that the pointer v points to. +func Elem(v reflect.Value) reflect.Value { + if v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface { + return v.Elem() + } + + // otherwise, will return self + return v +} + +// Indirect like reflect.Indirect(), but can also indirect reflect.Interface +func Indirect(v reflect.Value) reflect.Value { + if v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface { + return v.Elem() + } + + // otherwise, will return self + return v +} + +// Len get reflect value length +func Len(v reflect.Value) int { + v = reflect.Indirect(v) + + // (u)int use width. + switch v.Kind() { + case reflect.String: + return len([]rune(v.String())) + case reflect.Map, reflect.Array, reflect.Chan, reflect.Slice: + return v.Len() + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return len(strconv.FormatInt(int64(v.Uint()), 10)) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return len(strconv.FormatInt(v.Int(), 10)) + case reflect.Float32, reflect.Float64: + return len(fmt.Sprint(v.Interface())) + } + + // cannot get length + return -1 +} + +// SliceSubKind get sub-elem kind of the array, slice, variadic-var. alias SliceElemKind() +func SliceSubKind(typ reflect.Type) reflect.Kind { + return SliceElemKind(typ) +} + +// SliceElemKind get sub-elem kind of the array, slice, variadic-var. +// +// Usage: +// +// SliceElemKind(reflect.TypeOf([]string{"abc"})) // reflect.String +func SliceElemKind(typ reflect.Type) reflect.Kind { + if typ.Kind() == reflect.Slice || typ.Kind() == reflect.Array { + return typ.Elem().Kind() + } + return reflect.Invalid +} + +// UnexportedValue quickly get unexported value by reflect.Value +// +// NOTE: this method is unsafe, use it carefully. +// should ensure rv is addressable by field.CanAddr() +// +// refer: https://stackoverflow.com/questions/42664837/how-to-access-unexported-struct-fields +func UnexportedValue(rv reflect.Value) any { + if rv.CanAddr() { + // create new value from addr, now can be read and set. + return reflect.NewAt(rv.Type(), unsafe.Pointer(rv.UnsafeAddr())).Elem().Interface() + } + + // If the rv is not addressable this trick won't work, but you can create an addressable copy like this + rs2 := reflect.New(rv.Type()).Elem() + rs2.Set(rv) + rv = rs2.Field(0) + rv = reflect.NewAt(rv.Type(), unsafe.Pointer(rv.UnsafeAddr())).Elem() + // Now rv can be read. TIP: Setting will succeed but only affects the temporary copy. + return rv.Interface() +} + +// SetUnexportedValue quickly set unexported field value by reflect +// +// NOTE: this method is unsafe, use it carefully. +// should ensure rv is addressable by field.CanAddr() +func SetUnexportedValue(rv reflect.Value, value any) { + reflect.NewAt(rv.Type(), unsafe.Pointer(rv.UnsafeAddr())).Elem().Set(reflect.ValueOf(value)) +} + +// SetValue to a `reflect.Value`. will auto convert type if needed. +func SetValue(rv reflect.Value, val any) error { + // get real type of the ptr value + if rv.Kind() == reflect.Ptr { + if rv.IsNil() { + elemTyp := rv.Type().Elem() + rv.Set(reflect.New(elemTyp)) + } + + // use elem for set value + rv = reflect.Indirect(rv) + } + + rv1, err := ValueByType(val, rv.Type()) + if err == nil { + rv.Set(rv1) + } + return err +} + +// SetRValue to a `reflect.Value`. will direct set value without convert type. +func SetRValue(rv, val reflect.Value) { + if rv.Kind() == reflect.Ptr { + if rv.IsNil() { + elemTyp := rv.Type().Elem() + rv.Set(reflect.New(elemTyp)) + } + rv = reflect.Indirect(rv) + } + + rv.Set(val) +} + +// EachMap process any map data +func EachMap(mp reflect.Value, fn func(key, val reflect.Value)) { + if fn == nil { + return + } + if mp.Kind() != reflect.Map { + panic("only allow map value data") + } + + for _, key := range mp.MapKeys() { + fn(key, mp.MapIndex(key)) + } +} + +// EachStrAnyMap process any map data as string key and any value +func EachStrAnyMap(mp reflect.Value, fn func(key string, val any)) { + EachMap(mp, func(key, val reflect.Value) { + fn(String(key), val.Interface()) + }) +} + +// FlatFunc custom collect handle func +type FlatFunc func(path string, val reflect.Value) + +// FlatMap process tree map to flat key-value map. +// +// Examples: +// +// {"top": {"sub": "value", "sub2": "value2"} } +// -> +// {"top.sub": "value", "top.sub2": "value2" } +func FlatMap(rv reflect.Value, fn FlatFunc) { + if fn == nil { + return + } + + if rv.Kind() != reflect.Map { + panic("only allow flat map data") + } + flatMap(rv, fn, "") +} + +func flatMap(rv reflect.Value, fn FlatFunc, parent string) { + for _, key := range rv.MapKeys() { + path := String(key) + if parent != "" { + path = parent + "." + path + } + + fv := Indirect(rv.MapIndex(key)) + switch fv.Kind() { + case reflect.Map: + flatMap(fv, fn, path) + case reflect.Array, reflect.Slice: + flatSlice(fv, fn, path) + default: + fn(path, fv) + } + } +} + +func flatSlice(rv reflect.Value, fn FlatFunc, parent string) { + for i := 0; i < rv.Len(); i++ { + path := parent + "[" + strconv.Itoa(i) + "]" + fv := Indirect(rv.Index(i)) + + switch fv.Kind() { + case reflect.Map: + flatMap(fv, fn, path) + case reflect.Array, reflect.Slice: + flatSlice(fv, fn, path) + default: + fn(path, fv) + } + } +} diff --git a/nstd/chan.go b/nstd/chan.go new file mode 100644 index 0000000..d947895 --- /dev/null +++ b/nstd/chan.go @@ -0,0 +1,63 @@ +package nstd + +import ( + "context" + "fmt" + "io" + "os" + "os/signal" + "syscall" +) + +// WaitCloseSignals for some huang program. +func WaitCloseSignals(closer io.Closer) error { + signals := make(chan os.Signal, 1) + signal.Notify(signals, os.Interrupt, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT) + <-signals + + return closer.Close() +} + +// Go is a basic promise implementation: it wraps calls a function in a goroutine +// and returns a channel which will later return the function's return value. +func Go(f func() error) error { + ch := make(chan error) + go func() { + ch <- f() + }() + + return <-ch +} + +// SignalHandler returns an actor, i.e. an execute and interrupt func, that +// terminates with SignalError when the process receives one of the provided +// signals, or the parent context is canceled. +// +// from https://github.com/oklog/run/blob/master/actors.go +func SignalHandler(ctx context.Context, signals ...os.Signal) (execute func() error, interrupt func(error)) { + ctx, cancel := context.WithCancel(ctx) + return func() error { + c := make(chan os.Signal, 1) + signal.Notify(c, signals...) + defer signal.Stop(c) + select { + case sig := <-c: + return SignalError{Signal: sig} + case <-ctx.Done(): + return ctx.Err() + } + }, func(error) { + cancel() + } +} + +// SignalError is returned by the signal handler's execute function +// when it terminates due to a received signal. +type SignalError struct { + Signal os.Signal +} + +// Error implements the error interface. +func (e SignalError) Error() string { + return fmt.Sprintf("received signal %s", e.Signal) +} diff --git a/nstd/check.go b/nstd/check.go new file mode 100644 index 0000000..dafe494 --- /dev/null +++ b/nstd/check.go @@ -0,0 +1,120 @@ +package nstd + +import ( + "git.noahlan.cn/noahlan/ntool/nreflect" + "reflect" + "strings" +) + +// IsNil value check +func IsNil(v any) bool { + if v == nil { + return true + } + return nreflect.IsNil(reflect.ValueOf(v)) +} + +// IsEmpty value check +func IsEmpty(v any) bool { + if v == nil { + return true + } + return nreflect.IsEmpty(reflect.ValueOf(v)) +} + +// IsFunc value +func IsFunc(val any) bool { + if val == nil { + return false + } + return reflect.TypeOf(val).Kind() == reflect.Func +} + +// IsEqual determines if two objects are considered equal. +// +// TIP: cannot compare function type +func IsEqual(src, dst any) bool { + if src == nil || dst == nil { + return src == dst + } + + // cannot compare function type + if IsFunc(src) || IsFunc(dst) { + return false + } + return nreflect.IsEqual(src, dst) +} + +// Contains try loop over the data check if the data includes the element. +// alias of the IsContains +// +// TIP: only support types: string, map, array, slice +// +// map - check key exists +// string - check sub-string exists +// array,slice - check sub-element exists +func Contains(data, elem any) bool { + _, found := CheckContains(data, elem) + return found +} + +// IsContains try loop over the data check if the data includes the element. +// +// TIP: only support types: string, map, array, slice +// +// map - check key exists +// string - check sub-string exists +// array,slice - check sub-element exists +func IsContains(data, elem any) bool { + _, found := CheckContains(data, elem) + return found +} + +// CheckContains try loop over the data check if the data includes the element. +// +// TIP: only support types: string, map, array, slice +// +// map - check key exists +// string - check sub-string exists +// array,slice - check sub-element exists +// +// return (false, false) if impossible. +// return (true, false) if element was not found. +// return (true, true) if element was found. +func CheckContains(data, elem any) (valid, found bool) { + dataRv := reflect.ValueOf(data) + dataRt := reflect.TypeOf(data) + if dataRt == nil { + return false, false + } + + dataKind := dataRt.Kind() + + // string + if dataKind == reflect.String { + return true, strings.Contains(dataRv.String(), reflect.ValueOf(elem).String()) + } + + // map + if dataKind == reflect.Map { + mapKeys := dataRv.MapKeys() + for i := 0; i < len(mapKeys); i++ { + if nreflect.IsEqual(mapKeys[i].Interface(), elem) { + return true, true + } + } + return true, false + } + + // array, slice - other return false + if dataKind != reflect.Slice && dataKind != reflect.Array { + return false, false + } + + for i := 0; i < dataRv.Len(); i++ { + if nreflect.IsEqual(dataRv.Index(i).Interface(), elem) { + return true, true + } + } + return true, false +} diff --git a/nstd/gofunc.go b/nstd/gofunc.go new file mode 100644 index 0000000..e5dae68 --- /dev/null +++ b/nstd/gofunc.go @@ -0,0 +1,90 @@ +package nstd + +import ( + "git.noahlan.cn/noahlan/ntool/nstr" + "reflect" + "runtime" + "strings" +) + +// FullFcName struct. +type FullFcName struct { + // FullName eg: git.noahlan.cn/noahlan/ntool/nstd.IsNil + FullName string + pkgPath string + pkgName string + funcName string +} + +// Parse the full func name. +func (ffn *FullFcName) Parse() { + if ffn.funcName != "" { + return + } + + i := strings.LastIndex(ffn.FullName, "/") + + ffn.pkgPath = ffn.FullName[:i+1] + // spilt get pkg and func name + ffn.pkgName, ffn.funcName = nstr.MustCut(ffn.FullName[i+1:], ".") + + ffn.pkgPath += ffn.pkgName +} + +// PkgPath string get. eg: git.noahlan.cn/noahlan/ntool/nstd +func (ffn *FullFcName) PkgPath() string { + ffn.Parse() + return ffn.pkgPath +} + +// PkgName string get. eg: nstd +func (ffn *FullFcName) PkgName() string { + ffn.Parse() + return ffn.pkgName +} + +// FuncName get short func name. eg: IsNil +func (ffn *FullFcName) FuncName() string { + ffn.Parse() + return ffn.funcName +} + +// String get full func name string. +func (ffn *FullFcName) String() string { + return ffn.FullName +} + +// FuncName get full func name, contains pkg path. +// +// eg: +// +// // OUTPUT: git.noahlan.cn/noahlan/ntool/nstd.IsNil +// nstd.FuncName(nstd.PkgName) +func FuncName(fn any) string { + return runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() +} + +// CutFuncName get pkg path and short func name +func CutFuncName(fullFcName string) (pkgPath, shortFnName string) { + ffn := FullFcName{FullName: fullFcName} + return ffn.PkgPath(), ffn.FuncName() +} + +// PkgName get current package name +// +// Usage: +// +// fullFcName := nstd.FuncName(fn) +// pgkName := nstd.PkgName(fullFcName) +func PkgName(fullFcName string) string { + for { + lastPeriod := strings.LastIndex(fullFcName, ".") + lastSlash := strings.LastIndex(fullFcName, "/") + if lastPeriod > lastSlash { + fullFcName = fullFcName[:lastPeriod] + } else { + break + } + } + return fullFcName +} diff --git a/nstd/io/writer.go b/nstd/io/writer.go new file mode 100644 index 0000000..53e30a8 --- /dev/null +++ b/nstd/io/writer.go @@ -0,0 +1,20 @@ +package io + +import ( + "fmt" + "io" +) + +// ByteStringWriter interface +type ByteStringWriter interface { + io.Writer + io.ByteWriter + io.StringWriter + fmt.Stringer +} + +// StringWriteStringer interface +type StringWriteStringer interface { + io.StringWriter + fmt.Stringer +} diff --git a/nstd/io/writer_wrapper.go b/nstd/io/writer_wrapper.go new file mode 100644 index 0000000..f209441 --- /dev/null +++ b/nstd/io/writer_wrapper.go @@ -0,0 +1,48 @@ +package io + +import ( + "fmt" + "io" +) + +// WriteWrapper warp io.Writer support more operate methods. +type WriteWrapper struct { + Out io.Writer +} + +// NewWriteWrapper instance +func NewWriteWrapper(w io.Writer) *WriteWrapper { + return &WriteWrapper{Out: w} +} + +// Write bytes data +func (w *WriteWrapper) Write(p []byte) (n int, err error) { + return w.Out.Write(p) +} + +// Writef data to output +func (w *WriteWrapper) Writef(tpl string, vs ...any) (n int, err error) { + return fmt.Fprintf(w.Out, tpl, vs...) +} + +// WriteByte data +func (w *WriteWrapper) WriteByte(c byte) error { + _, err := w.Out.Write([]byte{c}) + return err +} + +// WriteString data +func (w *WriteWrapper) WriteString(s string) (n int, err error) { + if sw, ok := w.Out.(io.StringWriter); ok { + return sw.WriteString(s) + } + return w.Out.Write([]byte(s)) +} + +// String get write data string +func (w *WriteWrapper) String() string { + if sw, ok := w.Out.(fmt.Stringer); ok { + return sw.String() + } + return "" +} diff --git a/nstd/tea/tea.go b/nstd/tea/tea.go new file mode 100644 index 0000000..ded1642 --- /dev/null +++ b/nstd/tea/tea.go @@ -0,0 +1,491 @@ +package tea + +func String(a string) *string { + return &a +} + +func StringValue(a *string) string { + if a == nil { + return "" + } + return *a +} + +func Int(a int) *int { + return &a +} + +func IntValue(a *int) int { + if a == nil { + return 0 + } + return *a +} + +func Int8(a int8) *int8 { + return &a +} + +func Int8Value(a *int8) int8 { + if a == nil { + return 0 + } + return *a +} + +func Int16(a int16) *int16 { + return &a +} + +func Int16Value(a *int16) int16 { + if a == nil { + return 0 + } + return *a +} + +func Int32(a int32) *int32 { + return &a +} + +func Int32Value(a *int32) int32 { + if a == nil { + return 0 + } + return *a +} + +func Int64(a int64) *int64 { + return &a +} + +func Int64Value(a *int64) int64 { + if a == nil { + return 0 + } + return *a +} + +func Bool(a bool) *bool { + return &a +} + +func BoolValue(a *bool) bool { + if a == nil { + return false + } + return *a +} + +func Uint(a uint) *uint { + return &a +} + +func UintValue(a *uint) uint { + if a == nil { + return 0 + } + return *a +} + +func Uint8(a uint8) *uint8 { + return &a +} + +func Uint8Value(a *uint8) uint8 { + if a == nil { + return 0 + } + return *a +} + +func Uint16(a uint16) *uint16 { + return &a +} + +func Uint16Value(a *uint16) uint16 { + if a == nil { + return 0 + } + return *a +} + +func Uint32(a uint32) *uint32 { + return &a +} + +func Uint32Value(a *uint32) uint32 { + if a == nil { + return 0 + } + return *a +} + +func Uint64(a uint64) *uint64 { + return &a +} + +func Uint64Value(a *uint64) uint64 { + if a == nil { + return 0 + } + return *a +} + +func Float32(a float32) *float32 { + return &a +} + +func Float32Value(a *float32) float32 { + if a == nil { + return 0 + } + return *a +} + +func Float64(a float64) *float64 { + return &a +} + +func Float64Value(a *float64) float64 { + if a == nil { + return 0 + } + return *a +} + +func IntSlice(a []int) []*int { + if a == nil { + return nil + } + res := make([]*int, len(a)) + for i := 0; i < len(a); i++ { + res[i] = &a[i] + } + return res +} + +func IntValueSlice(a []*int) []int { + if a == nil { + return nil + } + res := make([]int, len(a)) + for i := 0; i < len(a); i++ { + if a[i] != nil { + res[i] = *a[i] + } + } + return res +} + +func Int8Slice(a []int8) []*int8 { + if a == nil { + return nil + } + res := make([]*int8, len(a)) + for i := 0; i < len(a); i++ { + res[i] = &a[i] + } + return res +} + +func Int8ValueSlice(a []*int8) []int8 { + if a == nil { + return nil + } + res := make([]int8, len(a)) + for i := 0; i < len(a); i++ { + if a[i] != nil { + res[i] = *a[i] + } + } + return res +} + +func Int16Slice(a []int16) []*int16 { + if a == nil { + return nil + } + res := make([]*int16, len(a)) + for i := 0; i < len(a); i++ { + res[i] = &a[i] + } + return res +} + +func Int16ValueSlice(a []*int16) []int16 { + if a == nil { + return nil + } + res := make([]int16, len(a)) + for i := 0; i < len(a); i++ { + if a[i] != nil { + res[i] = *a[i] + } + } + return res +} + +func Int32Slice(a []int32) []*int32 { + if a == nil { + return nil + } + res := make([]*int32, len(a)) + for i := 0; i < len(a); i++ { + res[i] = &a[i] + } + return res +} + +func Int32ValueSlice(a []*int32) []int32 { + if a == nil { + return nil + } + res := make([]int32, len(a)) + for i := 0; i < len(a); i++ { + if a[i] != nil { + res[i] = *a[i] + } + } + return res +} + +func Int64Slice(a []int64) []*int64 { + if a == nil { + return nil + } + res := make([]*int64, len(a)) + for i := 0; i < len(a); i++ { + res[i] = &a[i] + } + return res +} + +func Int64ValueSlice(a []*int64) []int64 { + if a == nil { + return nil + } + res := make([]int64, len(a)) + for i := 0; i < len(a); i++ { + if a[i] != nil { + res[i] = *a[i] + } + } + return res +} + +func UintSlice(a []uint) []*uint { + if a == nil { + return nil + } + res := make([]*uint, len(a)) + for i := 0; i < len(a); i++ { + res[i] = &a[i] + } + return res +} + +func UintValueSlice(a []*uint) []uint { + if a == nil { + return nil + } + res := make([]uint, len(a)) + for i := 0; i < len(a); i++ { + if a[i] != nil { + res[i] = *a[i] + } + } + return res +} + +func Uint8Slice(a []uint8) []*uint8 { + if a == nil { + return nil + } + res := make([]*uint8, len(a)) + for i := 0; i < len(a); i++ { + res[i] = &a[i] + } + return res +} + +func Uint8ValueSlice(a []*uint8) []uint8 { + if a == nil { + return nil + } + res := make([]uint8, len(a)) + for i := 0; i < len(a); i++ { + if a[i] != nil { + res[i] = *a[i] + } + } + return res +} + +func Uint16Slice(a []uint16) []*uint16 { + if a == nil { + return nil + } + res := make([]*uint16, len(a)) + for i := 0; i < len(a); i++ { + res[i] = &a[i] + } + return res +} + +func Uint16ValueSlice(a []*uint16) []uint16 { + if a == nil { + return nil + } + res := make([]uint16, len(a)) + for i := 0; i < len(a); i++ { + if a[i] != nil { + res[i] = *a[i] + } + } + return res +} + +func Uint32Slice(a []uint32) []*uint32 { + if a == nil { + return nil + } + res := make([]*uint32, len(a)) + for i := 0; i < len(a); i++ { + res[i] = &a[i] + } + return res +} + +func Uint32ValueSlice(a []*uint32) []uint32 { + if a == nil { + return nil + } + res := make([]uint32, len(a)) + for i := 0; i < len(a); i++ { + if a[i] != nil { + res[i] = *a[i] + } + } + return res +} + +func Uint64Slice(a []uint64) []*uint64 { + if a == nil { + return nil + } + res := make([]*uint64, len(a)) + for i := 0; i < len(a); i++ { + res[i] = &a[i] + } + return res +} + +func Uint64ValueSlice(a []*uint64) []uint64 { + if a == nil { + return nil + } + res := make([]uint64, len(a)) + for i := 0; i < len(a); i++ { + if a[i] != nil { + res[i] = *a[i] + } + } + return res +} + +func Float32Slice(a []float32) []*float32 { + if a == nil { + return nil + } + res := make([]*float32, len(a)) + for i := 0; i < len(a); i++ { + res[i] = &a[i] + } + return res +} + +func Float32ValueSlice(a []*float32) []float32 { + if a == nil { + return nil + } + res := make([]float32, len(a)) + for i := 0; i < len(a); i++ { + if a[i] != nil { + res[i] = *a[i] + } + } + return res +} + +func Float64Slice(a []float64) []*float64 { + if a == nil { + return nil + } + res := make([]*float64, len(a)) + for i := 0; i < len(a); i++ { + res[i] = &a[i] + } + return res +} + +func Float64ValueSlice(a []*float64) []float64 { + if a == nil { + return nil + } + res := make([]float64, len(a)) + for i := 0; i < len(a); i++ { + if a[i] != nil { + res[i] = *a[i] + } + } + return res +} + +func StringSlice(a []string) []*string { + if a == nil { + return nil + } + res := make([]*string, len(a)) + for i := 0; i < len(a); i++ { + res[i] = &a[i] + } + return res +} + +func StringSliceValue(a []*string) []string { + if a == nil { + return nil + } + res := make([]string, len(a)) + for i := 0; i < len(a); i++ { + if a[i] != nil { + res[i] = *a[i] + } + } + return res +} + +func BoolSlice(a []bool) []*bool { + if a == nil { + return nil + } + res := make([]*bool, len(a)) + for i := 0; i < len(a); i++ { + res[i] = &a[i] + } + return res +} + +func BoolSliceValue(a []*bool) []bool { + if a == nil { + return nil + } + res := make([]bool, len(a)) + for i := 0; i < len(a); i++ { + if a[i] != nil { + res[i] = *a[i] + } + } + return res +} diff --git a/nstr/ac/README.md b/nstr/ac/README.md new file mode 100644 index 0000000..6ff5799 --- /dev/null +++ b/nstr/ac/README.md @@ -0,0 +1,85 @@ +# aho-corasick +Efficient string matching in Golang via the aho-corasick algorithm. + +x20 faster than https://github.com/cloudflare/ahocorasick and x3 faster than https://github.com/anknown/ahocorasick + +Memory consuption is a eigth of https://github.com/cloudflare/ahocorasick and half of https://github.com/anknown/ahocorasick + +This library is heavily inspired by https://github.com/BurntSushi/aho-corasick + +## Usage + +```bash +go get -u github.com/petar-dambovaliev/aho-corasick +``` + +```go +import ( + ahocorasick "github.com/petar-dambovaliev/aho-corasick" +) +builder := ahocorasick.NewAhoCorasickBuilder(Opts{ + AsciiCaseInsensitive: true, + MatchOnlyWholeWords: true, + MatchKind: LeftMostLongestMatch, + DFA: true, +}) + +ac := builder.Build([]string{"bear", "masha"}) +haystack := "The Bear and Masha" +matches := ac.FindAll(haystack) + +for _, match := range matches { + println(haystack[match.Start():match.End()]) +} +``` + +Matching can be done via `NFA` or `DFA`. +`NFA` has runtime complexity O(N + M) in relation to the haystack and number of matches. +`DFA` has runtime complexity O(N), but it uses more memory. + +Replacing of matches in the haystack. + +`replaceWith` needs to be the same length as the `patterns` +```go +r := ahocorasick.NewReplacer(ac) +replaced := r.ReplaceAll(haystack, replaceWith) +``` + +`ReplaceAllFunc` is useful, for example, if you want to use the original text cassing but you are matching +case insensitively. You can replace partially by return false and from that point, the original string will be preserved. +```go +replaced := r.ReplaceAllFunc(haystack, func(match Match) (string, bool) { + return `` + haystack[match.Start():match.End()] + `<\a>`, true +}) +``` + +Search for matches one at a time via the iterator + +```go +iter := ac.Iter(haystack) + +for next := iter.Next(); next != nil; next = iter.Next() { + ... +} +``` + +It's plenty fast but if you want to use it in parallel, that is also possible. + +Memory consumption won't increase because the read-only automaton is not actually copied, only the counters are. + +The magic line is `ac := ac` + +```go +var w sync.WaitGroup + +w.Add(50) +for i := 0; i < 50; i++ { + go func() { + ac := ac + matches := ac.FindAll(haystack) + println(len(matches)) + w.Done() + }() +} +w.Wait() +``` \ No newline at end of file diff --git a/nstr/ac/ahocorasick.go b/nstr/ac/ahocorasick.go new file mode 100644 index 0000000..2662b88 --- /dev/null +++ b/nstr/ac/ahocorasick.go @@ -0,0 +1,386 @@ +package ac + +import ( + "strings" + "sync" + "unicode" +) + +type findIter struct { + fsm imp + prestate *prefilterState + haystack []byte + pos int + matchOnlyWholeWords bool +} + +// Iter is an iterator over matches found on the current haystack +// it gives the user more granular control. You can choose how many and what kind of matches you need. +type Iter interface { + Next() *Match +} + +// Next gives a pointer to the next match yielded by the iterator or nil, if there is none +func (f *findIter) Next() *Match { + if f.pos > len(f.haystack) { + return nil + } + + result := f.fsm.FindAtNoState(f.prestate, f.haystack, f.pos) + + if result == nil { + return nil + } + + if result.end == f.pos { + f.pos += 1 + } else { + f.pos = result.end + } + + if f.matchOnlyWholeWords { + if result.Start()-1 >= 0 && (unicode.IsLetter(rune(f.haystack[result.Start()-1])) || unicode.IsDigit(rune(f.haystack[result.Start()-1]))) { + return f.Next() + } + if result.end < len(f.haystack) && (unicode.IsLetter(rune(f.haystack[result.end])) || unicode.IsDigit(rune(f.haystack[result.end]))) { + return f.Next() + } + } + + return result +} + +type overlappingIter struct { + fsm imp + prestate *prefilterState + haystack []byte + pos int + stateID stateID + matchIndex int + matchOnlyWholeWords bool +} + +func (f *overlappingIter) Next() *Match { + if f.pos > len(f.haystack) { + return nil + } + + result := f.fsm.OverlappingFindAt(f.prestate, f.haystack, f.pos, &f.stateID, &f.matchIndex) + + if result == nil { + return nil + } + + f.pos = result.End() + + if f.matchOnlyWholeWords { + if result.Start()-1 >= 0 && (unicode.IsLetter(rune(f.haystack[result.Start()-1])) || unicode.IsDigit(rune(f.haystack[result.Start()-1]))) { + return f.Next() + } + if result.end < len(f.haystack) && (unicode.IsLetter(rune(f.haystack[result.end])) || unicode.IsDigit(rune(f.haystack[result.end]))) { + return f.Next() + } + } + + return result +} + +func newOverlappingIter(ac AhoCorasick, haystack []byte) overlappingIter { + prestate := prefilterState{ + skips: 0, + skipped: 0, + maxMatchLen: ac.i.MaxPatternLen(), + inert: false, + lastScanAt: 0, + } + return overlappingIter{ + fsm: ac.i, + prestate: &prestate, + haystack: haystack, + pos: 0, + stateID: ac.i.StartState(), + matchIndex: 0, + matchOnlyWholeWords: ac.matchOnlyWholeWords, + } +} + +// make sure the AhoCorasick data structure implements the Finder interface +var _ Finder = (*AhoCorasick)(nil) + +// AhoCorasick is the main data structure that does most of the work +type AhoCorasick struct { + i imp + matchKind matchKind + matchOnlyWholeWords bool +} + +func (ac AhoCorasick) PatternCount() int { + return ac.i.PatternCount() +} + +// Iter gives an iterator over the built patterns +func (ac AhoCorasick) Iter(haystack string) Iter { + return ac.IterByte([]byte(haystack)) +} + +// IterByte gives an iterator over the built patterns +func (ac AhoCorasick) IterByte(haystack []byte) Iter { + prestate := &prefilterState{ + skips: 0, + skipped: 0, + maxMatchLen: ac.i.MaxPatternLen(), + inert: false, + lastScanAt: 0, + } + + return &findIter{ + fsm: ac.i, + prestate: prestate, + haystack: haystack, + pos: 0, + matchOnlyWholeWords: ac.matchOnlyWholeWords, + } +} + +// IterOverlapping gives an iterator over the built patterns with overlapping matches +func (ac AhoCorasick) IterOverlapping(haystack string) Iter { + return ac.IterOverlappingByte([]byte(haystack)) +} + +// IterOverlappingByte gives an iterator over the built patterns with overlapping matches +func (ac AhoCorasick) IterOverlappingByte(haystack []byte) Iter { + if ac.matchKind != StandardMatch { + panic("only StandardMatch allowed for overlapping matches") + } + i := newOverlappingIter(ac, haystack) + return &i +} + +var pool = sync.Pool{ + New: func() interface{} { + return strings.Builder{} + }, +} + +type Replacer struct { + finder Finder +} + +//goland:noinspection GoUnusedExportedFunction +func NewReplacer(finder Finder) Replacer { + return Replacer{finder: finder} +} + +// ReplaceAllFunc replaces the matches found in the haystack according to the user provided function +// it gives fine-grained control over what is replaced. +// A user can choose to stop the replacing process early by returning false in the lambda +// In that case, everything from that point will be kept as the original haystack +func (r Replacer) ReplaceAllFunc(haystack string, f func(match Match) (string, bool)) string { + matches := r.finder.FindAll(haystack) + + if len(matches) == 0 { + return haystack + } + + replaceWith := make([]string, 0) + + for _, match := range matches { + rw, ok := f(match) + if !ok { + break + } + replaceWith = append(replaceWith, rw) + } + + str := pool.Get().(strings.Builder) + + defer func() { + str.Reset() + pool.Put(str) + }() + + start := 0 + + for i, match := range matches { + if i >= len(replaceWith) { + str.WriteString(haystack[start:]) + return str.String() + } + str.WriteString(haystack[start:match.Start()]) + str.WriteString(replaceWith[i]) + start = match.Start() + match.len + } + + if start-1 < len(haystack) { + str.WriteString(haystack[start:]) + } + + return str.String() +} + +// ReplaceAll replaces the matches found in the haystack according to the user provided slice `replaceWith` +// It panics, if `replaceWith` has length different from the patterns that it was built with +func (r Replacer) ReplaceAll(haystack string, replaceWith []string) string { + if len(replaceWith) != r.finder.PatternCount() { + panic("replaceWith needs to have the same length as the pattern count") + } + + return r.ReplaceAllFunc(haystack, func(match Match) (string, bool) { + return replaceWith[match.pattern], true + }) +} + +type Finder interface { + FindAll(haystack string) []Match + PatternCount() int +} + +// FindAll returns the matches found in the haystack +func (ac AhoCorasick) FindAll(haystack string) []Match { + iter := ac.Iter(haystack) + matches := make([]Match, 0) + + for { + next := iter.Next() + if next == nil { + break + } + + matches = append(matches, *next) + } + + return matches +} + +// AhoCorasickBuilder defines a set of options applied before the patterns are built +type AhoCorasickBuilder struct { + dfaBuilder *iDFABuilder + nfaBuilder *iNFABuilder + dfa bool + matchOnlyWholeWords bool +} + +// Opts defines a set of options applied before the patterns are built +type Opts struct { + AsciiCaseInsensitive bool + MatchOnlyWholeWords bool + MatchKind matchKind + DFA bool +} + +// NewAhoCorasickBuilder creates a new AhoCorasickBuilder based on Opts +//goland:noinspection GoUnusedExportedFunction +func NewAhoCorasickBuilder(o Opts) AhoCorasickBuilder { + return AhoCorasickBuilder{ + dfaBuilder: newDFABuilder(), + nfaBuilder: newNFABuilder(o.MatchKind, o.AsciiCaseInsensitive), + dfa: o.DFA, + matchOnlyWholeWords: o.MatchOnlyWholeWords, + } +} + +// Build builds a (non)deterministic finite automata from the user provided patterns +func (a *AhoCorasickBuilder) Build(patterns []string) AhoCorasick { + bytePatterns := make([][]byte, len(patterns)) + for i, pat := range patterns { + bytePatterns[i] = []byte(pat) + } + + return a.BuildByte(bytePatterns) +} + +// BuildByte builds a (non)deterministic finite automata from the user provided patterns +func (a *AhoCorasickBuilder) BuildByte(patterns [][]byte) AhoCorasick { + nfa := a.nfaBuilder.build(patterns) + kind := nfa.matchKind + + if a.dfa { + dfa := a.dfaBuilder.build(nfa) + return AhoCorasick{dfa, kind, a.matchOnlyWholeWords} + } + + return AhoCorasick{nfa, kind, a.matchOnlyWholeWords} +} + +type imp interface { + MatchKind() *matchKind + StartState() stateID + MaxPatternLen() int + PatternCount() int + Prefilter() prefilter + UsePrefilter() bool + OverlappingFindAt(prestate *prefilterState, haystack []byte, at int, stateId *stateID, matchIndex *int) *Match + EarliestFindAt(prestate *prefilterState, haystack []byte, at int, stateId *stateID) *Match + FindAtNoState(prestate *prefilterState, haystack []byte, at int) *Match +} + +type matchKind int + +const ( + // StandardMatch Use standard match semantics, which support overlapping matches. When + // used with non-overlapping matches, matches are reported as they are seen. + StandardMatch matchKind = iota + // LeftMostFirstMatch Use leftmost-first match semantics, which reports leftmost matches. + // When there are multiple possible leftmost matches, the match + // corresponding to the pattern that appeared earlier when constructing + // the automaton is reported. + // This does **not** support overlapping matches or stream searching + LeftMostFirstMatch + // LeftMostLongestMatch Use leftmost-longest match semantics, which reports leftmost matches. + // When there are multiple possible leftmost matches, the longest match is chosen. + LeftMostLongestMatch +) + +func (m matchKind) supportsOverlapping() bool { + return m.isStandard() +} + +func (m matchKind) supportsStream() bool { + return m.isStandard() +} + +func (m matchKind) isStandard() bool { + return m == StandardMatch +} + +func (m matchKind) isLeftmost() bool { + return m == LeftMostFirstMatch || m == LeftMostLongestMatch +} + +func (m matchKind) isLeftmostFirst() bool { + return m == LeftMostFirstMatch +} + +// Match A representation of a match reported by an Aho-Corasick automaton. +// +// A match has two essential pieces of information: the identifier of the +// pattern that matched, along with the start and end offsets of the match +// in the haystack. +type Match struct { + pattern int + len int + end int +} + +// Pattern returns the index of the pattern in the slice of the patterns provided by the user that +// was matched +func (m *Match) Pattern() int { + return m.pattern +} + +// End gives the index of the last character of this match inside the haystack +func (m *Match) End() int { + return m.end +} + +// Start gives the index of the first character of this match inside the haystack +func (m *Match) Start() int { + return m.end - m.len +} + +type stateID uint + +const ( + failedStateID stateID = 0 + deadStateID stateID = 1 +) diff --git a/nstr/ac/automaton.go b/nstr/ac/automaton.go new file mode 100644 index 0000000..b9d7817 --- /dev/null +++ b/nstr/ac/automaton.go @@ -0,0 +1,222 @@ +package ac + +type automaton interface { + Repr() *iRepr + MatchKind() *matchKind + Anchored() bool + Prefilter() prefilter + StartState() stateID + IsValid(stateID) bool + IsMatchState(stateID) bool + IsMatchOrDeadState(stateID) bool + GetMatch(stateID, int, int) *Match + MatchCount(stateID) int + NextState(stateID, byte) stateID + NextStateNoFail(stateID, byte) stateID + StandardFindAt(*prefilterState, []byte, int, *stateID) *Match + StandardFindAtImp(*prefilterState, prefilter, []byte, int, *stateID) *Match + LeftmostFindAt(*prefilterState, []byte, int, *stateID) *Match + LeftmostFindAtImp(*prefilterState, prefilter, []byte, int, *stateID) *Match + LeftmostFindAtNoState(*prefilterState, []byte, int) *Match + LeftmostFindAtNoStateImp(*prefilterState, prefilter, []byte, int) *Match + OverlappingFindAt(*prefilterState, []byte, int, *stateID, *int) *Match + EarliestFindAt(*prefilterState, []byte, int, *stateID) *Match + FindAt(*prefilterState, []byte, int, *stateID) *Match + FindAtNoState(*prefilterState, []byte, int) *Match +} + +func isMatchOrDeadState(a automaton, si stateID) bool { + return si == deadStateID || a.IsMatchState(si) +} + +func standardFindAt(a automaton, prestate *prefilterState, haystack []byte, at int, sID *stateID) *Match { + pre := a.Prefilter() + return a.StandardFindAtImp(prestate, pre, haystack, at, sID) +} + +func standardFindAtImp(a automaton, prestate *prefilterState, prefilter prefilter, haystack []byte, at int, sID *stateID) *Match { + for at < len(haystack) { + if prefilter != nil { + startState := a.StartState() + if prestate.IsEffective(at) && sID == &startState { + c, typ := nextPrefilter(prestate, prefilter, haystack, at) + switch typ { + case noneCandidate: + return nil + case possibleStartOfMatchCandidate: + i := c.(int) + at = i + } + } + } + *sID = a.NextStateNoFail(*sID, haystack[at]) + at += 1 + + if a.IsMatchOrDeadState(*sID) { + if *sID == deadStateID { + return nil + } else { + return a.GetMatch(*sID, 0, at) + } + } + } + return nil +} + +func leftmostFindAt(a automaton, prestate *prefilterState, haystack []byte, at int, sID *stateID) *Match { + prefilter := a.Prefilter() + return a.LeftmostFindAtImp(prestate, prefilter, haystack, at, sID) +} + +func leftmostFindAtImp(a automaton, prestate *prefilterState, prefilter prefilter, haystack []byte, at int, sID *stateID) *Match { + if a.Anchored() && at > 0 && *sID == a.StartState() { + return nil + } + lastMatch := a.GetMatch(*sID, 0, at) + + for at < len(haystack) { + if prefilter != nil { + startState := a.StartState() + if prestate.IsEffective(at) && sID == &startState { + c, typ := nextPrefilter(prestate, prefilter, haystack, at) + switch typ { + case noneCandidate: + return nil + case possibleStartOfMatchCandidate: + i := c.(int) + at = i + } + } + } + + *sID = a.NextStateNoFail(*sID, haystack[at]) + at += 1 + + if a.IsMatchOrDeadState(*sID) { + if *sID == deadStateID { + return lastMatch + } else { + a.GetMatch(*sID, 0, at) + } + } + } + + return lastMatch +} + +func leftmostFindAtNoState(a automaton, prestate *prefilterState, haystack []byte, at int) *Match { + return leftmostFindAtNoStateImp(a, prestate, a.Prefilter(), haystack, at) +} + +func leftmostFindAtNoStateImp(a automaton, prestate *prefilterState, prefilter prefilter, haystack []byte, at int) *Match { + if a.Anchored() && at > 0 { + return nil + } + if prefilter != nil && !prefilter.ReportsFalsePositives() { + c, typ := prefilter.NextCandidate(prestate, haystack, at) + switch typ { + case noneCandidate: + return nil + case matchCandidate: + m := c.(*Match) + return m + } + } + + stateID := a.StartState() + lastMatch := a.GetMatch(stateID, 0, at) + + for at < len(haystack) { + if prefilter != nil && prestate.IsEffective(at) && stateID == a.StartState() { + c, typ := prefilter.NextCandidate(prestate, haystack, at) + switch typ { + case noneCandidate: + return nil + case matchCandidate: + m := c.(*Match) + return m + case possibleStartOfMatchCandidate: + i := c.(int) + at = i + } + } + + stateID = a.NextStateNoFail(stateID, haystack[at]) + at += 1 + + if a.IsMatchOrDeadState(stateID) { + if stateID == deadStateID { + return lastMatch + } + lastMatch = a.GetMatch(stateID, 0, at) + } + } + + return lastMatch +} + +func overlappingFindAt(a automaton, prestate *prefilterState, haystack []byte, at int, id *stateID, matchIndex *int) *Match { + if a.Anchored() && at > 0 && *id == a.StartState() { + return nil + } + + matchCount := a.MatchCount(*id) + + if *matchIndex < matchCount { + result := a.GetMatch(*id, *matchIndex, at) + *matchIndex += 1 + return result + } + + *matchIndex = 0 + match := a.StandardFindAt(prestate, haystack, at, id) + + if match == nil { + return nil + } + + *matchIndex = 1 + return match +} + +func earliestFindAt(a automaton, prestate *prefilterState, haystack []byte, at int, id *stateID) *Match { + if *id == a.StartState() { + if a.Anchored() && at > 0 { + return nil + } + match := a.GetMatch(*id, 0, at) + if match != nil { + return match + } + } + return a.StandardFindAt(prestate, haystack, at, id) +} + +func findAt(a automaton, prestate *prefilterState, haystack []byte, at int, id *stateID) *Match { + kind := a.MatchKind() + if kind == nil { + return nil + } + switch *kind { + case StandardMatch: + return a.EarliestFindAt(prestate, haystack, at, id) + case LeftMostFirstMatch, LeftMostLongestMatch: + return a.LeftmostFindAt(prestate, haystack, at, id) + } + return nil +} + +func findAtNoState(a automaton, prestate *prefilterState, haystack []byte, at int) *Match { + kind := a.MatchKind() + if kind == nil { + return nil + } + switch *kind { + case StandardMatch: + state := a.StartState() + return a.EarliestFindAt(prestate, haystack, at, &state) + case LeftMostFirstMatch, LeftMostLongestMatch: + return a.LeftmostFindAtNoState(prestate, haystack, at) + } + return nil +} diff --git a/nstr/ac/byte_frequencies.go b/nstr/ac/byte_frequencies.go new file mode 100644 index 0000000..8ab4111 --- /dev/null +++ b/nstr/ac/byte_frequencies.go @@ -0,0 +1,260 @@ +package ac + +var byteFrequencies = [256]byte{ + 55, // '\x00' + 52, // '\x01' + 51, // '\x02' + 50, // '\x03' + 49, // '\x04' + 48, // '\x05' + 47, // '\x06' + 46, // '\x07' + 45, // '\x08' + 103, // '\t' + 242, // '\n' + 66, // '\x0b' + 67, // '\x0c' + 229, // '\r' + 44, // '\x0e' + 43, // '\x0f' + 42, // '\x10' + 41, // '\x11' + 40, // '\x12' + 39, // '\x13' + 38, // '\x14' + 37, // '\x15' + 36, // '\x16' + 35, // '\x17' + 34, // '\x18' + 33, // '\x19' + 56, // '\x1a' + 32, // '\x1b' + 31, // '\x1c' + 30, // '\x1d' + 29, // '\x1e' + 28, // '\x1f' + 255, // ' ' + 148, // '!' + 164, // '"' + 149, // '#' + 136, // '$' + 160, // '%' + 155, // '&' + 173, // "'" + 221, // '(' + 222, // ')' + 134, // '*' + 122, // '+' + 232, // ',' + 202, // '-' + 215, // '.' + 224, // '/' + 208, // '0' + 220, // '1' + 204, // '2' + 187, // '3' + 183, // '4' + 179, // '5' + 177, // '6' + 168, // '7' + 178, // '8' + 200, // '9' + 226, // ':' + 195, // ';' + 154, // '<' + 184, // '=' + 174, // '>' + 126, // '?' + 120, // '@' + 191, // 'A' + 157, // 'B' + 194, // 'C' + 170, // 'D' + 189, // 'E' + 162, // 'F' + 161, // 'G' + 150, // 'H' + 193, // 'I' + 142, // 'J' + 137, // 'K' + 171, // 'L' + 176, // 'M' + 185, // 'N' + 167, // 'O' + 186, // 'P' + 112, // 'Q' + 175, // 'R' + 192, // 'S' + 188, // 'T' + 156, // 'U' + 140, // 'V' + 143, // 'W' + 123, // 'X' + 133, // 'Y' + 128, // 'Z' + 147, // '[' + 138, // '\\' + 146, // ']' + 114, // '^' + 223, // '_' + 151, // '`' + 249, // 'a' + 216, // 'b' + 238, // 'c' + 236, // 'd' + 253, // 'e' + 227, // 'f' + 218, // 'g' + 230, // 'h' + 247, // 'i' + 135, // 'j' + 180, // 'k' + 241, // 'l' + 233, // 'm' + 246, // 'n' + 244, // 'o' + 231, // 'p' + 139, // 'q' + 245, // 'r' + 243, // 's' + 251, // 't' + 235, // 'u' + 201, // 'v' + 196, // 'w' + 240, // 'x' + 214, // 'y' + 152, // 'z' + 182, // '{' + 205, // '|' + 181, // '}' + 127, // '~' + 27, // '\x7f' + 212, // '\x80' + 211, // '\x81' + 210, // '\x82' + 213, // '\x83' + 228, // '\x84' + 197, // '\x85' + 169, // '\x86' + 159, // '\x87' + 131, // '\x88' + 172, // '\x89' + 105, // '\x8a' + 80, // '\x8b' + 98, // '\x8c' + 96, // '\x8d' + 97, // '\x8e' + 81, // '\x8f' + 207, // '\x90' + 145, // '\x91' + 116, // '\x92' + 115, // '\x93' + 144, // '\x94' + 130, // '\x95' + 153, // '\x96' + 121, // '\x97' + 107, // '\x98' + 132, // '\x99' + 109, // '\x9a' + 110, // '\x9b' + 124, // '\x9c' + 111, // '\x9d' + 82, // '\x9e' + 108, // '\x9f' + 118, // '\xa0' + 141, // '¡' + 113, // '¢' + 129, // '£' + 119, // '¤' + 125, // '¥' + 165, // '¦' + 117, // '§' + 92, // '¨' + 106, // '©' + 83, // 'ª' + 72, // '«' + 99, // '¬' + 93, // '\xad' + 65, // '®' + 79, // '¯' + 166, // '°' + 237, // '±' + 163, // '²' + 199, // '³' + 190, // '´' + 225, // 'µ' + 209, // '¶' + 203, // '·' + 198, // '¸' + 217, // '¹' + 219, // 'º' + 206, // '»' + 234, // '¼' + 248, // '½' + 158, // '¾' + 239, // '¿' + 255, // 'À' + 255, // 'Á' + 255, // 'Â' + 255, // 'Ã' + 255, // 'Ä' + 255, // 'Å' + 255, // 'Æ' + 255, // 'Ç' + 255, // 'È' + 255, // 'É' + 255, // 'Ê' + 255, // 'Ë' + 255, // 'Ì' + 255, // 'Í' + 255, // 'Î' + 255, // 'Ï' + 255, // 'Ð' + 255, // 'Ñ' + 255, // 'Ò' + 255, // 'Ó' + 255, // 'Ô' + 255, // 'Õ' + 255, // 'Ö' + 255, // '×' + 255, // 'Ø' + 255, // 'Ù' + 255, // 'Ú' + 255, // 'Û' + 255, // 'Ü' + 255, // 'Ý' + 255, // 'Þ' + 255, // 'ß' + 255, // 'à' + 255, // 'á' + 255, // 'â' + 255, // 'ã' + 255, // 'ä' + 255, // 'å' + 255, // 'æ' + 255, // 'ç' + 255, // 'è' + 255, // 'é' + 255, // 'ê' + 255, // 'ë' + 255, // 'ì' + 255, // 'í' + 255, // 'î' + 255, // 'ï' + 255, // 'ð' + 255, // 'ñ' + 255, // 'ò' + 255, // 'ó' + 255, // 'ô' + 255, // 'õ' + 255, // 'ö' + 255, // '÷' + 255, // 'ø' + 255, // 'ù' + 255, // 'ú' + 255, // 'û' + 255, // 'ü' + 255, // 'ý' + 255, // 'þ' + 255, // 'ÿ' +} diff --git a/nstr/ac/classes.go b/nstr/ac/classes.go new file mode 100644 index 0000000..0db2b8b --- /dev/null +++ b/nstr/ac/classes.go @@ -0,0 +1,77 @@ +package ac + +import "math" + +type byteClassRepresentatives struct { + classes *byteClasses + bbyte int + lastClass *byte +} + +func (b *byteClassRepresentatives) next() *byte { + for b.bbyte < 256 { + bbyte := byte(b.bbyte) + class := b.classes.bytes[bbyte] + b.bbyte += 1 + + if b.lastClass == nil || *b.lastClass != class { + c := class + b.lastClass = &c + return &bbyte + } + } + return nil +} + +type byteClassBuilder []bool + +func (b byteClassBuilder) setRange(start, end byte) { + if start > 0 { + b[int(start)-1] = true + } + b[int(end)] = true +} + +func (b byteClassBuilder) build() byteClasses { + var classes byteClasses + var class byte + i := 0 + for { + classes.bytes[byte(i)] = class + if i >= 255 { + break + } + if b[i] { + if class+1 > math.MaxUint8 { + panic("shit happens") + } + class += 1 + } + i += 1 + } + return classes +} + +func newByteClassBuilder() byteClassBuilder { + return make([]bool, 256) +} + +type byteClasses struct { + bytes [256]byte +} + +func singletons() byteClasses { + var bc byteClasses + for i := range bc.bytes { + bc.bytes[i] = byte(i) + } + return bc +} + +func (b byteClasses) alphabetLen() int { + return int(b.bytes[255]) + 1 +} + +func (b byteClasses) isSingleton() bool { + return b.alphabetLen() == 256 +} diff --git a/nstr/ac/dfa.go b/nstr/ac/dfa.go new file mode 100644 index 0000000..ea47861 --- /dev/null +++ b/nstr/ac/dfa.go @@ -0,0 +1,729 @@ +package ac + +import "unsafe" + +type iDFA struct { + atom automaton +} + +func (d iDFA) MatchKind() *matchKind { + return d.atom.MatchKind() +} + +func (d iDFA) StartState() stateID { + return d.atom.StartState() +} + +func (d iDFA) MaxPatternLen() int { + return d.atom.Repr().maxPatternLen +} + +func (d iDFA) PatternCount() int { + return d.atom.Repr().patternCount +} + +func (d iDFA) Prefilter() prefilter { + return d.atom.Prefilter() +} + +func (d iDFA) UsePrefilter() bool { + p := d.Prefilter() + if p == nil { + return false + } + return !p.LooksForNonStartOfMatch() +} + +func (d iDFA) OverlappingFindAt(prestate *prefilterState, haystack []byte, at int, stateId *stateID, matchIndex *int) *Match { + return overlappingFindAt(d.atom, prestate, haystack, at, stateId, matchIndex) +} + +func (d iDFA) EarliestFindAt(prestate *prefilterState, haystack []byte, at int, stateId *stateID) *Match { + return earliestFindAt(d.atom, prestate, haystack, at, stateId) +} + +func (d iDFA) FindAtNoState(prestate *prefilterState, haystack []byte, at int) *Match { + return findAtNoState(d.atom, prestate, haystack, at) +} + +func (d iDFA) LeftmostFindAtNoState(prestate *prefilterState, haystack []byte, at int) *Match { + return leftmostFindAtNoState(d.atom, prestate, haystack, at) +} + +type iDFABuilder struct { + premultiply bool + byteClasses bool +} + +func (d *iDFABuilder) build(nfa *iNFA) iDFA { + var bc byteClasses + if d.byteClasses { + bc = nfa.byteClasses + } else { + bc = singletons() + } + + alphabetLen := bc.alphabetLen() + trans := make([]stateID, alphabetLen*len(nfa.states)) + for i := range trans { + trans[i] = failedStateID + } + + matches := make([][]pattern, len(nfa.states)) + var p prefilter + + if nfa.prefilter != nil { + p = nfa.prefilter.clone() + } + + rep := iRepr{ + matchKind: nfa.matchKind, + anchored: nfa.anchored, + premultiplied: false, + startId: nfa.startID, + maxPatternLen: nfa.maxPatternLen, + patternCount: nfa.patternCount, + stateCount: len(nfa.states), + maxMatch: failedStateID, + heapBytes: 0, + prefilter: p, + byteClasses: bc, + trans: trans, + matches: matches, + } + + for id := 0; id < len(nfa.states); id += 1 { + rep.matches[id] = append(rep.matches[id], nfa.states[id].matches...) + fail := nfa.states[id].fail + + nfa.iterAllTransitions(&bc, stateID(id), func(tr *next) { + if tr.id == failedStateID { + tr.id = nfaNextStateMemoized(nfa, &rep, stateID(id), fail, tr.key) + } + rep.setNextState(stateID(id), tr.key, tr.id) + }) + + } + + rep.shuffleMatchStates() + rep.calculateSize() + + if d.premultiply { + rep.premultiply() + if bc.isSingleton() { + return iDFA{&iPremultiplied{rep}} + } else { + return iDFA{&iPremultipliedByteClass{&rep}} + } + } + if bc.isSingleton() { + return iDFA{&iStandard{rep}} + } + return iDFA{&iByteClass{&rep}} +} + +type iByteClass struct { + repr *iRepr +} + +func (p iByteClass) FindAtNoState(prefilterState *prefilterState, bytes []byte, i int) *Match { + return findAtNoState(p, prefilterState, bytes, i) +} + +func (p iByteClass) Repr() *iRepr { + return p.repr +} + +func (p iByteClass) MatchKind() *matchKind { + return &p.repr.matchKind +} + +func (p iByteClass) Anchored() bool { + return p.repr.anchored +} + +func (p iByteClass) Prefilter() prefilter { + return p.repr.prefilter +} + +func (p iByteClass) StartState() stateID { + return p.repr.startId +} + +func (p iByteClass) IsValid(id stateID) bool { + return int(id) < p.repr.stateCount +} + +func (p iByteClass) IsMatchState(id stateID) bool { + return p.repr.isMatchState(id) +} + +func (p iByteClass) IsMatchOrDeadState(id stateID) bool { + return p.repr.isMatchStateOrDeadState(id) +} + +func (p iByteClass) GetMatch(id stateID, i int, i2 int) *Match { + return p.repr.GetMatch(id, i, i2) +} + +func (p iByteClass) MatchCount(id stateID) int { + return p.repr.MatchCount(id) +} + +func (p iByteClass) NextState(id stateID, b2 byte) stateID { + alphabetLen := p.repr.byteClasses.alphabetLen() + input := p.repr.byteClasses.bytes[b2] + o := int(id)*alphabetLen + int(input) + return p.repr.trans[o] +} + +func (p iByteClass) NextStateNoFail(id stateID, b byte) stateID { + next := p.NextState(id, b) + if next == failedStateID { + panic("automaton should never return fail_id for next state") + } + return next +} + +func (p iByteClass) StandardFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match { + return standardFindAt(&p, prefilterState, bytes, i, id) +} + +func (p iByteClass) StandardFindAtImp(prefilterState *prefilterState, prefilter prefilter, bytes []byte, i int, id *stateID) *Match { + return standardFindAtImp(&p, prefilterState, prefilter, bytes, i, id) +} + +func (p iByteClass) LeftmostFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match { + return leftmostFindAt(&p, prefilterState, bytes, i, id) +} + +func (p iByteClass) LeftmostFindAtImp(prefilterState *prefilterState, prefilter prefilter, bytes []byte, i int, id *stateID) *Match { + return leftmostFindAtImp(&p, prefilterState, prefilter, bytes, i, id) +} + +func (p iByteClass) LeftmostFindAtNoState(prefilterState *prefilterState, bytes []byte, i int) *Match { + return leftmostFindAtNoState(&p, prefilterState, bytes, i) +} + +func (p iByteClass) LeftmostFindAtNoStateImp(prefilterState *prefilterState, prefilter prefilter, bytes []byte, i int) *Match { + return leftmostFindAtNoStateImp(&p, prefilterState, prefilter, bytes, i) +} + +func (p iByteClass) OverlappingFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID, i2 *int) *Match { + return overlappingFindAt(&p, prefilterState, bytes, i, id, i2) +} + +func (p iByteClass) EarliestFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match { + return earliestFindAt(&p, prefilterState, bytes, i, id) +} + +func (p iByteClass) FindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match { + return findAt(&p, prefilterState, bytes, i, id) +} + +type iPremultipliedByteClass struct { + repr *iRepr +} + +func (p iPremultipliedByteClass) FindAtNoState(prefilterState *prefilterState, bytes []byte, i int) *Match { + return findAtNoState(p, prefilterState, bytes, i) +} + +func (p iPremultipliedByteClass) Repr() *iRepr { + return p.repr +} + +func (p iPremultipliedByteClass) MatchKind() *matchKind { + return &p.repr.matchKind +} + +func (p iPremultipliedByteClass) Anchored() bool { + return p.repr.anchored +} + +func (p iPremultipliedByteClass) Prefilter() prefilter { + return p.repr.prefilter +} + +func (p iPremultipliedByteClass) StartState() stateID { + return p.repr.startId +} + +func (p iPremultipliedByteClass) IsValid(id stateID) bool { + return (int(id) / p.repr.alphabetLen()) < p.repr.stateCount +} + +func (p iPremultipliedByteClass) IsMatchState(id stateID) bool { + return p.repr.isMatchState(id) +} + +func (p iPremultipliedByteClass) IsMatchOrDeadState(id stateID) bool { + return p.repr.isMatchStateOrDeadState(id) +} + +func (p iPremultipliedByteClass) GetMatch(id stateID, matchIndex int, end int) *Match { + if id > p.repr.maxMatch { + return nil + } + + m := p.repr.matches[int(id)/p.repr.alphabetLen()][matchIndex] + return &Match{ + pattern: m.PatternID, + len: m.PatternLength, + end: end, + } +} + +func (p iPremultipliedByteClass) MatchCount(id stateID) int { + o := int(id) / p.repr.alphabetLen() + return len(p.repr.matches[o]) +} + +func (p iPremultipliedByteClass) NextState(id stateID, b byte) stateID { + input := p.repr.byteClasses.bytes[b] + o := int(id) + int(input) + return p.repr.trans[o] +} + +func (p iPremultipliedByteClass) NextStateNoFail(id stateID, b byte) stateID { + // TODO this leaks garbage + n := p.NextState(id, b) + if n == failedStateID { + panic("automaton should never return fail_id for next state") + } + return n +} + +func (p iPremultipliedByteClass) StandardFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match { + return standardFindAt(&p, prefilterState, bytes, i, id) +} + +func (p iPremultipliedByteClass) StandardFindAtImp(prefilterState *prefilterState, prefilter prefilter, bytes []byte, i int, id *stateID) *Match { + return standardFindAtImp(&p, prefilterState, prefilter, bytes, i, id) +} + +func (p iPremultipliedByteClass) LeftmostFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match { + return leftmostFindAt(&p, prefilterState, bytes, i, id) +} + +func (p iPremultipliedByteClass) LeftmostFindAtImp(prefilterState *prefilterState, prefilter prefilter, bytes []byte, i int, id *stateID) *Match { + return leftmostFindAtImp(&p, prefilterState, prefilter, bytes, i, id) +} + +func (p iPremultipliedByteClass) LeftmostFindAtNoState(prefilterState *prefilterState, bytes []byte, i int) *Match { + return leftmostFindAtNoState(&p, prefilterState, bytes, i) +} + +func (p iPremultipliedByteClass) LeftmostFindAtNoStateImp(prefilterState *prefilterState, prefilter prefilter, bytes []byte, i int) *Match { + return leftmostFindAtNoStateImp(&p, prefilterState, prefilter, bytes, i) +} + +func (p iPremultipliedByteClass) OverlappingFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID, i2 *int) *Match { + return overlappingFindAt(&p, prefilterState, bytes, i, id, i2) +} + +func (p iPremultipliedByteClass) EarliestFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match { + return earliestFindAt(&p, prefilterState, bytes, i, id) +} + +func (p iPremultipliedByteClass) FindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match { + return findAt(&p, prefilterState, bytes, i, id) +} + +type iPremultiplied struct { + repr iRepr +} + +func (p iPremultiplied) FindAtNoState(prefilterState *prefilterState, bytes []byte, i int) *Match { + return findAtNoState(p, prefilterState, bytes, i) +} + +func (p iPremultiplied) Repr() *iRepr { + return &p.repr +} + +func (p iPremultiplied) MatchKind() *matchKind { + return &p.repr.matchKind +} + +func (p iPremultiplied) Anchored() bool { + return p.repr.anchored +} + +func (p iPremultiplied) Prefilter() prefilter { + return p.repr.prefilter +} + +func (p iPremultiplied) StartState() stateID { + return p.repr.startId +} + +func (p iPremultiplied) IsValid(id stateID) bool { + return int(id)/256 < p.repr.stateCount +} + +func (p iPremultiplied) IsMatchState(id stateID) bool { + return p.repr.isMatchState(id) +} + +func (p iPremultiplied) IsMatchOrDeadState(id stateID) bool { + return p.repr.isMatchStateOrDeadState(id) +} + +func (p iPremultiplied) GetMatch(id stateID, matchIndex int, end int) *Match { + if id > p.repr.maxMatch { + return nil + } + m := p.repr.matches[int(id)/256][matchIndex] + return &Match{ + pattern: m.PatternID, + len: m.PatternLength, + end: end, + } +} + +func (p iPremultiplied) MatchCount(id stateID) int { + return len(p.repr.matches[int(id)/256]) +} + +func (p iPremultiplied) NextState(id stateID, b byte) stateID { + o := int(id) + int(b) + return p.repr.trans[o] +} + +func (p iPremultiplied) NextStateNoFail(id stateID, b byte) stateID { + next := p.NextState(id, b) + if next == failedStateID { + panic("automaton should never return fail_id for next state") + } + return next +} + +func (p iPremultiplied) StandardFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match { + return standardFindAt(&p, prefilterState, bytes, i, id) +} + +func (p iPremultiplied) StandardFindAtImp(prefilterState *prefilterState, prefilter prefilter, bytes []byte, i int, id *stateID) *Match { + return standardFindAtImp(&p, prefilterState, prefilter, bytes, i, id) +} + +func (p iPremultiplied) LeftmostFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match { + return leftmostFindAt(&p, prefilterState, bytes, i, id) +} + +func (p iPremultiplied) LeftmostFindAtImp(prefilterState *prefilterState, prefilter prefilter, bytes []byte, i int, id *stateID) *Match { + return leftmostFindAtImp(&p, prefilterState, prefilter, bytes, i, id) +} + +func (p iPremultiplied) LeftmostFindAtNoState(prefilterState *prefilterState, bytes []byte, i int) *Match { + return leftmostFindAtNoState(&p, prefilterState, bytes, i) +} + +func (p iPremultiplied) LeftmostFindAtNoStateImp(prefilterState *prefilterState, prefilter prefilter, bytes []byte, i int) *Match { + return leftmostFindAtNoStateImp(&p, prefilterState, prefilter, bytes, i) +} + +func (p iPremultiplied) OverlappingFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID, i2 *int) *Match { + return overlappingFindAt(&p, prefilterState, bytes, i, id, i2) +} + +func (p iPremultiplied) EarliestFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match { + return earliestFindAt(&p, prefilterState, bytes, i, id) +} + +func (p iPremultiplied) FindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match { + return findAt(&p, prefilterState, bytes, i, id) +} + +func nfaNextStateMemoized(nfa *iNFA, dfa *iRepr, populating stateID, current stateID, input byte) stateID { + for { + if current < populating { + return dfa.nextState(current, input) + } + + next := nfa.states[current].nextState(input) + + if next != failedStateID { + return next + } + current = nfa.states[current].fail + } +} + +func newDFABuilder() *iDFABuilder { + return &iDFABuilder{ + premultiply: true, + byteClasses: true, + } +} + +type iStandard struct { + repr iRepr +} + +func (p *iStandard) FindAtNoState(prefilterState *prefilterState, bytes []byte, i int) *Match { + return findAtNoState(p, prefilterState, bytes, i) +} + +func (p *iStandard) Repr() *iRepr { + return &p.repr +} + +func (p *iStandard) MatchKind() *matchKind { + return &p.repr.matchKind +} + +func (p *iStandard) Anchored() bool { + return p.repr.anchored +} + +func (p *iStandard) Prefilter() prefilter { + return p.repr.prefilter +} + +func (p *iStandard) StartState() stateID { + return p.repr.startId +} + +func (p *iStandard) IsValid(id stateID) bool { + return int(id) < p.repr.stateCount +} + +func (p *iStandard) IsMatchState(id stateID) bool { + return p.repr.isMatchState(id) +} + +func (p *iStandard) IsMatchOrDeadState(id stateID) bool { + return p.repr.isMatchStateOrDeadState(id) +} + +func (p *iStandard) GetMatch(id stateID, matchIndex int, end int) *Match { + return p.repr.GetMatch(id, matchIndex, end) +} + +func (p *iStandard) MatchCount(id stateID) int { + return p.repr.MatchCount(id) +} + +func (p *iStandard) NextState(current stateID, input byte) stateID { + o := int(current)*256 + int(input) + return p.repr.trans[o] +} + +func (p *iStandard) NextStateNoFail(id stateID, b byte) stateID { + next := p.NextState(id, b) + if next == failedStateID { + panic("automaton should never return fail_id for next state") + } + return next +} + +func (p *iStandard) StandardFindAt(state *prefilterState, bytes []byte, i int, id *stateID) *Match { + return standardFindAt(p, state, bytes, i, id) +} + +func (p *iStandard) StandardFindAtImp(state *prefilterState, prefilter prefilter, bytes []byte, i int, id *stateID) *Match { + return standardFindAtImp(p, state, prefilter, bytes, i, id) +} + +func (p *iStandard) LeftmostFindAt(state *prefilterState, bytes []byte, i int, id *stateID) *Match { + return leftmostFindAt(p, state, bytes, i, id) +} + +func (p *iStandard) LeftmostFindAtImp(state *prefilterState, prefilter prefilter, bytes []byte, i int, id *stateID) *Match { + return leftmostFindAtImp(p, state, prefilter, bytes, i, id) +} + +func (p *iStandard) LeftmostFindAtNoState(state *prefilterState, bytes []byte, i int) *Match { + return leftmostFindAtNoState(p, state, bytes, i) +} + +func (p *iStandard) LeftmostFindAtNoStateImp(state *prefilterState, prefilter prefilter, bytes []byte, i int) *Match { + return leftmostFindAtNoStateImp(p, state, prefilter, bytes, i) +} + +func (p *iStandard) OverlappingFindAt(state *prefilterState, bytes []byte, i int, id *stateID, i2 *int) *Match { + return overlappingFindAt(p, state, bytes, i, id, i2) +} + +func (p *iStandard) EarliestFindAt(state *prefilterState, bytes []byte, i int, id *stateID) *Match { + return earliestFindAt(p, state, bytes, i, id) +} + +func (p *iStandard) FindAt(state *prefilterState, bytes []byte, i int, id *stateID) *Match { + return findAt(p, state, bytes, i, id) +} + +type iRepr struct { + matchKind matchKind + anchored bool + premultiplied bool + startId stateID + maxPatternLen int + patternCount int + stateCount int + maxMatch stateID + heapBytes int + prefilter prefilter + byteClasses byteClasses + trans []stateID + matches [][]pattern +} + +func (r *iRepr) premultiply() { + if r.premultiplied || r.stateCount <= 1 { + return + } + alphaLen := r.alphabetLen() + + for id := 2; id < r.stateCount; id++ { + offset := id * alphaLen + slice := r.trans[offset : offset+alphaLen] + for i := range slice { + if slice[i] == deadStateID { + continue + } + slice[i] = stateID(int(slice[i]) * alphaLen) + } + } + r.premultiplied = true + r.startId = stateID(int(r.startId) * alphaLen) + r.maxMatch = stateID(int(r.maxMatch) * alphaLen) +} + +func (r *iRepr) setNextState(from stateID, b byte, to stateID) { + alphabetLen := r.alphabetLen() + b = r.byteClasses.bytes[b] + r.trans[int(from)*alphabetLen+int(b)] = to +} + +func (r *iRepr) alphabetLen() int { + return r.byteClasses.alphabetLen() +} + +func (r *iRepr) nextState(from stateID, b byte) stateID { + alphabetLen := r.alphabetLen() + b = r.byteClasses.bytes[b] + return r.trans[int(from)*alphabetLen+int(b)] +} + +func (r *iRepr) isMatchState(id stateID) bool { + return id <= r.maxMatch && id > deadStateID +} + +func (r *iRepr) isMatchStateOrDeadState(id stateID) bool { + return id <= r.maxMatch +} + +func (r *iRepr) GetMatch(id stateID, matchIndex int, end int) *Match { + i := int(id) + if id > r.maxMatch { + return nil + } + if i > len(r.matches) { + return nil + } + matches := r.matches[int(id)] + if matchIndex > len(matches) { + return nil + } + pattern := matches[matchIndex] + + return &Match{ + pattern: pattern.PatternID, + len: pattern.PatternLength, + end: end, + } +} + +func (r *iRepr) MatchCount(id stateID) int { + return len(r.matches[id]) +} + +func (r *iRepr) swapStates(id1 stateID, id2 stateID) { + if r.premultiplied { + panic("cannot shuffle match states of premultiplied iDFA") + } + + o1 := int(id1) * r.alphabetLen() + o2 := int(id2) * r.alphabetLen() + + for b := 0; b < r.alphabetLen(); b++ { + r.trans[o1+b], r.trans[o2+b] = r.trans[o2+b], r.trans[o1+b] + } + r.matches[int(id1)], r.matches[int(id2)] = r.matches[int(id2)], r.matches[int(id1)] +} + +func (r *iRepr) calculateSize() { + intSize := int(unsafe.Sizeof(stateID(1))) + size := (len(r.trans) * intSize) + (len(r.matches) * (intSize * 3)) + + for _, stateMatches := range r.matches { + size += len(stateMatches) * (intSize * 2) + } + var hb int + if r.prefilter != nil { + hb = r.prefilter.HeapBytes() + } + size += hb + r.heapBytes = size +} + +func (r *iRepr) shuffleMatchStates() { + if r.premultiplied { + panic("cannot shuffle match states of premultiplied iDFA") + } + + if r.stateCount <= 1 { + return + } + + firstNonMatch := int(r.startId) + for firstNonMatch < r.stateCount && len(r.matches[firstNonMatch]) > 0 { + firstNonMatch += 1 + } + swaps := make([]stateID, r.stateCount) + + for i := range swaps { + swaps[i] = failedStateID + } + + cur := r.stateCount - 1 + + for cur > firstNonMatch { + if len(r.matches[cur]) > 0 { + r.swapStates(stateID(cur), stateID(firstNonMatch)) + swaps[cur] = stateID(firstNonMatch) + swaps[firstNonMatch] = stateID(cur) + + firstNonMatch += 1 + for firstNonMatch < cur && len(r.matches[firstNonMatch]) > 0 { + firstNonMatch += 1 + } + } + cur -= 1 + } + + for id := 0; id < r.stateCount; id++ { + alphabetLen := r.alphabetLen() + offset := id * alphabetLen + + slice := r.trans[offset : offset+alphabetLen] + + for i := range slice { + if swaps[slice[i]] != failedStateID { + slice[i] = swaps[slice[i]] + } + } + } + + if swaps[r.startId] != failedStateID { + r.startId = swaps[r.startId] + } + r.maxMatch = stateID(firstNonMatch - 1) +} + +type pattern struct { + PatternID int + PatternLength int +} diff --git a/nstr/ac/nfa.go b/nstr/ac/nfa.go new file mode 100644 index 0000000..c74e5d0 --- /dev/null +++ b/nstr/ac/nfa.go @@ -0,0 +1,822 @@ +package ac + +import ( + "sort" + "unsafe" +) + +type iNFA struct { + matchKind matchKind + startID stateID + maxPatternLen int + patternCount int + heapBytes int + prefilter prefilter + anchored bool + byteClasses byteClasses + states []state +} + +func (n *iNFA) FindAtNoState(prefilterState *prefilterState, bytes []byte, i int) *Match { + return findAtNoState(n, prefilterState, bytes, i) +} + +func (n *iNFA) Repr() *iRepr { + return nil +} + +func (n *iNFA) MatchKind() *matchKind { + return &n.matchKind +} + +func (n *iNFA) Anchored() bool { + return n.anchored +} + +func (n *iNFA) Prefilter() prefilter { + return n.prefilter +} + +func (n *iNFA) StartState() stateID { + return n.startID +} + +func (n *iNFA) IsValid(id stateID) bool { + return int(id) < len(n.states) +} + +func (n *iNFA) IsMatchState(id stateID) bool { + return n.state(id).isMatch() +} + +func (n *iNFA) IsMatchOrDeadState(id stateID) bool { + return isMatchOrDeadState(n, id) +} + +func (n *iNFA) MatchCount(id stateID) int { + return len(n.states[id].matches) +} + +func (n *iNFA) NextState(id stateID, b byte) stateID { + for { + state := n.states[id] + next := state.nextState(b) + if next != failedStateID { + return next + } + id = state.fail + } +} + +func (n *iNFA) NextStateNoFail(id stateID, b byte) stateID { + next := n.NextState(id, b) + if next == failedStateID { + panic("automaton should never return fail_id for next state") + } + return next +} + +func (n *iNFA) StandardFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match { + return standardFindAt(n, prefilterState, bytes, i, id) +} + +func (n *iNFA) StandardFindAtImp(prefilterState *prefilterState, prefilter prefilter, bytes []byte, i int, id *stateID) *Match { + return standardFindAtImp(n, prefilterState, prefilter, bytes, i, id) +} + +func (n *iNFA) LeftmostFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match { + return leftmostFindAt(n, prefilterState, bytes, i, id) +} + +func (n *iNFA) LeftmostFindAtImp(prefilterState *prefilterState, prefilter prefilter, bytes []byte, i int, id *stateID) *Match { + return leftmostFindAtImp(n, prefilterState, prefilter, bytes, i, id) +} + +func (n *iNFA) LeftmostFindAtNoState(prefilterState *prefilterState, bytes []byte, i int) *Match { + return leftmostFindAtNoState(n, prefilterState, bytes, i) +} + +func (n *iNFA) LeftmostFindAtNoStateImp(prefilterState *prefilterState, prefilter prefilter, bytes []byte, i int) *Match { + return leftmostFindAtNoStateImp(n, prefilterState, prefilter, bytes, i) +} + +func (n *iNFA) OverlappingFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID, i2 *int) *Match { + return overlappingFindAt(n, prefilterState, bytes, i, id, i2) +} + +func (n *iNFA) EarliestFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match { + return earliestFindAt(n, prefilterState, bytes, i, id) +} + +func (n *iNFA) FindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match { + return findAt(n, prefilterState, bytes, i, id) +} + +func (n *iNFA) MaxPatternLen() int { + return n.maxPatternLen +} + +func (n *iNFA) PatternCount() int { + return n.patternCount +} + +func (n *iNFA) UsePrefilter() bool { + p := n.Prefilter() + if p == nil { + return false + } + return !p.LooksForNonStartOfMatch() +} + +func (n *iNFA) GetMatch(id stateID, matchIndex int, end int) *Match { + if int(id) >= len(n.states) { + return nil + } + state := n.states[id] + if matchIndex >= len(state.matches) { + return nil + } + pat := state.matches[matchIndex] + return &Match{ + pattern: pat.PatternID, + len: pat.PatternLength, + end: end, + } +} + +func (n *iNFA) addDenseState(depth int) stateID { + d := newDense() + trans := transitions{dense: &d} + id := stateID(len(n.states)) + + fail := n.startID + + if n.anchored { + fail = deadStateID + } + + n.states = append(n.states, state{ + trans: trans, + fail: fail, + matches: nil, + depth: depth, + }) + return id +} + +func (n *iNFA) addSparseState(depth int) stateID { + trans := transitions{sparse: &sparse{inner: nil}} + id := stateID(len(n.states)) + + fail := n.startID + + if n.anchored { + fail = deadStateID + } + + n.states = append(n.states, state{ + trans: trans, + fail: fail, + matches: nil, + depth: depth, + }) + return id +} + +func (n *iNFA) state(id stateID) *state { + return &n.states[int(id)] +} + +type compiler struct { + builder iNFABuilder + prefilter prefilterBuilder + nfa iNFA + byteClassBuilder byteClassBuilder +} + +func (c *compiler) compile(patterns [][]byte) *iNFA { + c.addState(0) + c.addState(0) + c.addState(0) + + c.buildTrie(patterns) + + c.addStartStateLoop() + c.addDeadStateLoop() + + if !c.builder.anchored { + if c.builder.matchKind.isLeftmost() { + c.fillFailureTransitionsLeftmost() + } else { + c.fillFailureTransitionsStandard() + } + } + c.closeStartStateLoop() + + c.nfa.byteClasses = c.byteClassBuilder.build() + if !c.builder.anchored { + c.nfa.prefilter = c.prefilter.build() + } + c.calculateSize() + + return &c.nfa +} + +func (c *compiler) calculateSize() { + var size int + for _, state := range c.nfa.states { + size += state.heapBytes() + } + + c.nfa.heapBytes = size +} + +func (c *compiler) closeStartStateLoop() { + if c.builder.anchored || (c.builder.matchKind.isLeftmost() && c.nfa.state(c.nfa.startID).isMatch()) { + startId := c.nfa.startID + start := c.nfa.state(startId) + + for b := 0; b < 256; b++ { + if start.nextState(byte(b)) == startId { + start.setNextState(byte(b), deadStateID) + } + } + } +} + +type queuedState struct { + id stateID + matchAtDepth *int +} + +func startQueuedState(nfa *iNFA) queuedState { + var matchAtDepth *int + if nfa.states[nfa.startID].isMatch() { + r := 0 + matchAtDepth = &r + } + return queuedState{id: nfa.startID, matchAtDepth: matchAtDepth} +} + +func (q *queuedState) nextQueuedState(nfa *iNFA, id stateID) queuedState { + nextMatchAtDepth := q.nextMatchAtDepth(nfa, id) + return queuedState{id, nextMatchAtDepth} +} + +func (q *queuedState) nextMatchAtDepth( + nfa *iNFA, + next stateID, +) *int { + switch q.matchAtDepth { + case nil: + if !nfa.state(next).isMatch() { + return nil + } + default: + return q.matchAtDepth + } + + depth := nfa.state(next).depth - *nfa.state(next).getLongestMatch() + 1 + return &depth +} + +func (c *compiler) fillFailureTransitionsStandard() { + queue := make([]stateID, 0) + seen := c.queuedSet() + + for b := 0; b < 256; b++ { + next := c.nfa.state(c.nfa.startID).nextState(byte(b)) + if next != c.nfa.startID { + if !seen.contains(next) { + queue = append(queue, next) + seen.insert(next) + } + } + } + + for len(queue) > 0 { + id := queue[0] + queue = queue[1:] + it := newIterTransitions(&c.nfa, id) + + for next := it.next(); next != nil; next = it.next() { + if seen.contains(next.id) { + continue + } + queue = append(queue, next.id) + seen.insert(next.id) + + fail := it.nfa.state(id).fail + for it.nfa.state(fail).nextState(next.key) == failedStateID { + fail = it.nfa.state(fail).fail + } + fail = it.nfa.state(fail).nextState(next.key) + it.nfa.state(next.id).fail = fail + it.nfa.copyMatches(fail, next.id) + } + it.nfa.copyEmptyMatches(id) + } +} + +func (c *compiler) fillFailureTransitionsLeftmost() { + queue := make([]queuedState, 0) + seen := c.queuedSet() + start := startQueuedState(&c.nfa) + + for b := 0; b < 256; b++ { + nextId := c.nfa.state(c.nfa.startID).nextState(byte(b)) + if nextId != start.id { + next := start.nextQueuedState(&c.nfa, nextId) + if !seen.contains(next.id) { + queue = append(queue, next) + seen.insert(next.id) + } + if c.nfa.state(nextId).isMatch() { + c.nfa.state(nextId).fail = deadStateID + } + } + } + + for len(queue) > 0 { + item := queue[0] + queue = queue[1:] + anyTrans := false + it := newIterTransitions(&c.nfa, item.id) + tr := it.next() + for tr != nil { + anyTrans = true + next := item.nextQueuedState(it.nfa, tr.id) + if seen.contains(next.id) { + tr = it.next() + continue + } + queue = append(queue, next) + seen.insert(next.id) + + fail := it.nfa.state(item.id).fail + for it.nfa.state(fail).nextState(tr.key) == failedStateID { + fail = it.nfa.state(fail).fail + } + fail = it.nfa.state(fail).nextState(tr.key) + + if next.matchAtDepth != nil { + failDepth := it.nfa.state(fail).depth + nextDepth := it.nfa.state(next.id).depth + if nextDepth-*next.matchAtDepth+1 > failDepth { + it.nfa.state(next.id).fail = deadStateID + tr = it.next() + continue + } + + if start.id == it.nfa.state(next.id).fail { + panic("states that are match states or follow match states should never have a failure transition back to the start state in leftmost searching") + } + } + it.nfa.state(next.id).fail = fail + it.nfa.copyMatches(fail, next.id) + tr = it.next() + } + if !anyTrans && it.nfa.state(item.id).isMatch() { + it.nfa.state(item.id).fail = deadStateID + } + } +} + +func (n *iNFA) copyEmptyMatches(dst stateID) { + n.copyMatches(n.startID, dst) +} + +func (n *iNFA) copyMatches(src stateID, dst stateID) { + srcState, dstState := n.getTwo(src, dst) + dstState.matches = append(dstState.matches, srcState.matches...) +} + +func (n *iNFA) getTwo(i stateID, j stateID) (*state, *state) { + if i == j { + panic("src and dst should not be equal") + } + + if i < j { + before, after := n.states[0:j], n.states[j:] + return &before[i], &after[0] + } + + before, after := n.states[0:i], n.states[i:] + return &after[0], &before[j] +} + +func (n *iNFA) iterAllTransitions(byteClasses *byteClasses, id stateID, f func(tr *next)) { + n.states[id].trans.iterAll(byteClasses, f) +} + +func newIterTransitions(nfa *iNFA, stateId stateID) iterTransitions { + return iterTransitions{ + nfa: nfa, + stateId: stateId, + cur: 0, + } +} + +type iterTransitions struct { + nfa *iNFA + stateId stateID + cur int +} + +type next struct { + key byte + id stateID +} + +func (i *iterTransitions) next() *next { + sparse := i.nfa.states[int(i.stateId)].trans.sparse + if sparse != nil { + if i.cur >= len(sparse.inner) { + return nil + } + ii := i.cur + i.cur += 1 + return &next{ + key: sparse.inner[ii].b, + id: sparse.inner[ii].s, + } + } + + dense := i.nfa.states[int(i.stateId)].trans.dense + for i.cur < len(dense.inner) { + if i.cur >= 256 { + panic("There are always exactly 255 transitions in dense repr") + } + + b := byte(i.cur) + id := dense.inner[b] + i.cur += 1 + if id != failedStateID { + return &next{ + key: b, + id: id, + } + } + } + return nil +} + +type queuedSet struct { + set map[stateID]struct{} + ind int +} + +func newInertQueuedSet() queuedSet { + return queuedSet{ + set: make(map[stateID]struct{}), + ind: 0, + } +} + +func (q *queuedSet) contains(s stateID) bool { + _, ok := q.set[s] + return ok +} + +func (q *queuedSet) insert(s stateID) { + q.set[s] = struct{}{} +} + +func newActiveQueuedSet() queuedSet { + return queuedSet{ + set: make(map[stateID]struct{}, 0), + ind: 0, + } +} + +func (c *compiler) queuedSet() queuedSet { + if c.builder.asciiCaseInsensitive { + return newActiveQueuedSet() + } + return newInertQueuedSet() +} + +func (c *compiler) addStartStateLoop() { + startId := c.nfa.startID + start := c.nfa.state(startId) + for b := 0; b < 256; b++ { + if start.nextState(byte(b)) == failedStateID { + start.setNextState(byte(b), startId) + } + } +} + +func (c *compiler) addDeadStateLoop() { + dead := c.nfa.state(deadStateID) + for b := 0; b < 256; b++ { + dead.setNextState(byte(b), deadStateID) + } +} + +func (c *compiler) buildTrie(patterns [][]byte) { +Patterns: + for i, pat := range patterns { + c.nfa.maxPatternLen = max(c.nfa.maxPatternLen, len(pat)) + c.nfa.patternCount += 1 + + prev := c.nfa.startID + sawMatch := false + + for depth, b := range pat { + sawMatch = sawMatch || c.nfa.state(prev).isMatch() + if c.builder.matchKind.isLeftmostFirst() && sawMatch { + continue Patterns + } + + c.byteClassBuilder.setRange(b, b) + + if c.builder.asciiCaseInsensitive { + b := oppositeAsciiCase(b) + c.byteClassBuilder.setRange(b, b) + } + + next := c.nfa.state(prev).nextState(b) + + if next != failedStateID { + prev = next + } else { + next := c.addState(depth + 1) + c.nfa.state(prev).setNextState(b, next) + if c.builder.asciiCaseInsensitive { + b := oppositeAsciiCase(b) + c.nfa.state(prev).setNextState(b, next) + } + prev = next + } + } + c.nfa.state(prev).addMatch(i, len(pat)) + + if c.builder.prefilter { + c.prefilter.add(pat) + } + } +} + +const asciiCaseMask byte = 0b0010_0000 + +func toAsciiLowercase(b byte) byte { + return b | (1 * asciiCaseMask) +} + +func toAsciiUpper(b byte) byte { + b &= ^(1 * asciiCaseMask) + return b +} + +func oppositeAsciiCase(b byte) byte { + if 'A' <= b && b <= 'Z' { + return toAsciiLowercase(b) + } else if 'a' <= b && b <= 'z' { + return toAsciiUpper(b) + } + return b +} + +func (c *compiler) addState(depth int) stateID { + if depth < c.builder.denseDepth { + return c.nfa.addDenseState(depth) + + } + return c.nfa.addSparseState(depth) +} + +func newCompiler(builder iNFABuilder) compiler { + p := newPrefilterBuilder(builder.asciiCaseInsensitive) + + return compiler{ + builder: builder, + prefilter: p, + nfa: iNFA{ + matchKind: builder.matchKind, + startID: 2, + maxPatternLen: 0, + patternCount: 0, + heapBytes: 0, + prefilter: nil, + anchored: builder.anchored, + byteClasses: singletons(), + states: nil, + }, + byteClassBuilder: newByteClassBuilder(), + } +} + +type iNFABuilder struct { + denseDepth int + matchKind matchKind + prefilter bool + anchored bool + asciiCaseInsensitive bool +} + +func newNFABuilder(kind matchKind, asciiCaseInsensitive bool) *iNFABuilder { + return &iNFABuilder{ + denseDepth: 2, + matchKind: kind, + prefilter: true, + anchored: false, + asciiCaseInsensitive: asciiCaseInsensitive, + } +} + +func (b *iNFABuilder) build(patterns [][]byte) *iNFA { + c := newCompiler(*b) + return c.compile(patterns) +} + +type state struct { + trans transitions + fail stateID + matches []pattern + depth int +} + +func (s *state) heapBytes() int { + var i int + intSize := int(unsafe.Sizeof(i)) + return s.trans.heapBytes() + (len(s.matches) * (intSize * 2)) +} + +func (s *state) addMatch(patternID, patternLength int) { + s.matches = append(s.matches, pattern{ + PatternID: patternID, + PatternLength: patternLength, + }) +} + +func (s *state) isMatch() bool { + return len(s.matches) > 0 +} + +func (s *state) getLongestMatch() *int { + if len(s.matches) == 0 { + return nil + } + longest := s.matches[0].PatternLength + return &longest +} + +func (s *state) nextState(input byte) stateID { + return s.trans.nextState(input) +} + +func (s *state) setNextState(input byte, next stateID) { + s.trans.setNextState(input, next) +} + +type transitions struct { + sparse *sparse + dense *dense +} + +func sparseIter(trans []innerSparse, f func(*next)) { + var byte16 uint16 + + for _, tr := range trans { + for byte16 < uint16(tr.b) { + f(&next{ + key: byte(byte16), + id: failedStateID, + }) + byte16 += 1 + } + f(&next{ + key: tr.b, + id: tr.s, + }) + byte16 += 1 + } + + for b := byte16; b < 256; b++ { + f(&next{ + key: byte(b), + id: failedStateID, + }) + } +} + +func (t *transitions) iterAll(byteClasses *byteClasses, f func(tr *next)) { + if byteClasses.isSingleton() { + if t.sparse != nil { + sparseIter(t.sparse.inner, f) + } + + if t.dense != nil { + for b := 0; b < 256; b++ { + f(&next{ + key: byte(b), + id: t.dense.inner[b], + }) + } + } + } else { + if t.sparse != nil { + var lastClass *byte + + sparseIter(t.sparse.inner, func(n *next) { + class := byteClasses.bytes[n.key] + + if lastClass == nil || *lastClass != class { + cc := class + lastClass = &cc + f(n) + } + }) + } + + if t.dense != nil { + bcr := byteClassRepresentatives{ + classes: byteClasses, + bbyte: 0, + lastClass: nil, + } + + for n := bcr.next(); n != nil; n = bcr.next() { + f(&next{ + key: *n, + id: t.dense.inner[*n], + }) + } + } + } + +} + +func (t *transitions) heapBytes() int { + var i int + intSize := int(unsafe.Sizeof(i)) + if t.sparse != nil { + return len(t.sparse.inner) * (2 * intSize) + } + return len(t.dense.inner) * intSize +} + +func (t *transitions) nextState(input byte) stateID { + if t.sparse != nil { + for _, sp := range t.sparse.inner { + if sp.b == input { + return sp.s + } + } + return failedStateID + } + return t.dense.inner[input] +} + +func (t *transitions) setNextState(input byte, next stateID) { + if t.sparse != nil { + idx := sort.Search(len(t.sparse.inner), func(i int) bool { + return t.sparse.inner[i].b >= input + }) + + if idx < len(t.sparse.inner) && t.sparse.inner[idx].b == input { + t.sparse.inner[idx].s = next + } else { + if len(t.sparse.inner) > 0 { + is := innerSparse{ + b: input, + s: next, + } + if idx == len(t.sparse.inner) { + t.sparse.inner = append(t.sparse.inner, is) + } else { + t.sparse.inner = append( + t.sparse.inner[:idx+1], + t.sparse.inner[idx:]...) + t.sparse.inner[idx] = is + } + } else { + t.sparse.inner = []innerSparse{ + { + b: input, + s: next, + }, + } + } + } + return + } + t.dense.inner[int(input)] = next +} + +func newDense() dense { + return dense{inner: make([]stateID, 256)} +} + +type dense struct { + inner []stateID +} + +type innerSparse struct { + b byte + s stateID +} + +type sparse struct { + inner []innerSparse +} diff --git a/nstr/ac/prefilter.go b/nstr/ac/prefilter.go new file mode 100644 index 0000000..55b7876 --- /dev/null +++ b/nstr/ac/prefilter.go @@ -0,0 +1,601 @@ +package ac + +import ( + "math" +) + +type startBytesThree struct { + byte1 byte + byte2 byte + byte3 byte +} + +func (s *startBytesThree) NextCandidate(_ *prefilterState, haystack []byte, at int) (interface{}, candidateType) { + for i, b := range haystack[at:] { + if s.byte1 == b || s.byte2 == b || s.byte3 == b { + return at + i, possibleStartOfMatchCandidate + } + } + return nil, noneCandidate +} + +func (s *startBytesThree) HeapBytes() int { + return 0 +} + +func (s *startBytesThree) ReportsFalsePositives() bool { + return true +} + +func (s *startBytesThree) LooksForNonStartOfMatch() bool { + return false +} + +func (s *startBytesThree) clone() prefilter { + if s == nil { + return nil + } + u := *s + return &u +} + +type startBytesTwo struct { + byte1 byte + byte2 byte +} + +func (s *startBytesTwo) NextCandidate(_ *prefilterState, haystack []byte, at int) (interface{}, candidateType) { + for i, b := range haystack[at:] { + if s.byte1 == b || s.byte2 == b { + return at + i, possibleStartOfMatchCandidate + } + } + return nil, noneCandidate +} + +func (s *startBytesTwo) HeapBytes() int { + return 0 +} + +func (s *startBytesTwo) ReportsFalsePositives() bool { + return true +} + +func (s *startBytesTwo) LooksForNonStartOfMatch() bool { + return false +} + +func (s *startBytesTwo) clone() prefilter { + if s == nil { + return nil + } + u := *s + return &u +} + +type startBytesOne struct { + byte1 byte +} + +func (s *startBytesOne) NextCandidate(_ *prefilterState, haystack []byte, at int) (interface{}, candidateType) { + for i, b := range haystack[at:] { + if s.byte1 == b { + return at + i, possibleStartOfMatchCandidate + } + } + return nil, noneCandidate +} + +func (s *startBytesOne) HeapBytes() int { + return 0 +} + +func (s *startBytesOne) ReportsFalsePositives() bool { + return true +} + +func (s *startBytesOne) LooksForNonStartOfMatch() bool { + return false +} + +func (s *startBytesOne) clone() prefilter { + if s == nil { + return nil + } + u := *s + return &u +} + +type byteSet [256]bool + +func (b *byteSet) contains(bb byte) bool { + return b[int(bb)] +} + +func (b *byteSet) insert(bb byte) bool { + n := !b.contains(bb) + b[int(bb)] = true + return n +} + +type rareByteOffset struct { + max byte +} + +type rareByteOffsets struct { + rbo [256]rareByteOffset +} + +func (r *rareByteOffsets) set(b byte, off rareByteOffset) { + m := byte(max(int(r.rbo[int(b)].max), int(off.max))) + r.rbo[int(b)].max = m +} + +type prefilterBuilder struct { + count int + asciiCaseInsensitive bool + startBytes startBytesBuilder + rareBytes rareBytesBuilder +} + +func (p *prefilterBuilder) build() prefilter { + startBytes := p.startBytes.build() + rareBytes := p.rareBytes.build() + + switch true { + case startBytes != nil && rareBytes != nil: + hasFewerBytes := p.startBytes.count < p.rareBytes.count + + hasRarerBytes := p.startBytes.rankSum <= p.rareBytes.rankSum+50 + if hasFewerBytes || hasRarerBytes { + return startBytes + } else { + return rareBytes + } + case startBytes != nil: + return startBytes + case rareBytes != nil: + return rareBytes + case p.asciiCaseInsensitive: + return nil + default: + return nil + } +} + +func (p *prefilterBuilder) add(bytes []byte) { + p.count += 1 + p.startBytes.add(bytes) + p.rareBytes.add(bytes) +} + +func newPrefilterBuilder(asciiCaseInsensitive bool) prefilterBuilder { + return prefilterBuilder{ + count: 0, + asciiCaseInsensitive: asciiCaseInsensitive, + startBytes: newStartBytesBuilder(asciiCaseInsensitive), + rareBytes: newRareBytesBuilder(asciiCaseInsensitive), + } +} + +type rareBytesBuilder struct { + asciiCaseInsensitive bool + rareSet byteSet + byteOffsets rareByteOffsets + available bool + count int + rankSum uint16 +} + +type rareBytesOne struct { + byte1 byte + offset rareByteOffset +} + +func (r *rareBytesOne) NextCandidate(state *prefilterState, haystack []byte, at int) (interface{}, candidateType) { + for i, b := range haystack[at:] { + if r.byte1 == b { + pos := at + i + state.lastScanAt = pos + r := pos - int(r.offset.max) + if r < 0 { + r = 0 + } + + if at > r { + r = at + } + return r, possibleStartOfMatchCandidate + } + } + return nil, noneCandidate +} + +func (r *rareBytesOne) HeapBytes() int { + return 0 +} + +func (r *rareBytesOne) ReportsFalsePositives() bool { + return true +} + +func (r *rareBytesOne) LooksForNonStartOfMatch() bool { + return true +} + +func (r *rareBytesOne) clone() prefilter { + if r == nil { + return nil + } + u := *r + return &u +} + +type rareBytesTwo struct { + offsets rareByteOffsets + byte1 byte + byte2 byte +} + +func (r *rareBytesTwo) NextCandidate(state *prefilterState, haystack []byte, at int) (interface{}, candidateType) { + for i, b := range haystack[at:] { + if r.byte1 == b || r.byte2 == b { + pos := at + i + state.updateAt(pos) + r := pos - int(r.offsets.rbo[haystack[pos]].max) + if r < 0 { + r = 0 + } + + if at > r { + r = at + } + return r, possibleStartOfMatchCandidate + } + } + return nil, noneCandidate +} + +func (r *rareBytesTwo) HeapBytes() int { + return 0 +} + +func (r *rareBytesTwo) ReportsFalsePositives() bool { + return true +} + +func (r *rareBytesTwo) LooksForNonStartOfMatch() bool { + return true +} + +func (r *rareBytesTwo) clone() prefilter { + if r == nil { + return nil + } + u := *r + return &u +} + +type rareBytesThree struct { + offsets rareByteOffsets + byte1 byte + byte2 byte + byte3 byte +} + +func (r *rareBytesThree) NextCandidate(state *prefilterState, haystack []byte, at int) (interface{}, candidateType) { + for i, b := range haystack[at:] { + if r.byte1 == b || r.byte2 == b || r.byte3 == b { + pos := at + i + state.updateAt(pos) + r := pos - int(r.offsets.rbo[haystack[pos]].max) + if r < 0 { + r = 0 + } + + if at > r { + r = at + } + return r, possibleStartOfMatchCandidate + } + } + return nil, noneCandidate +} + +func (r *rareBytesThree) HeapBytes() int { + return 0 +} + +func (r *rareBytesThree) ReportsFalsePositives() bool { + return true +} + +func (r *rareBytesThree) LooksForNonStartOfMatch() bool { + return true +} + +func (r *rareBytesThree) clone() prefilter { + if r == nil { + return nil + } + u := *r + return &u +} + +func (r *rareBytesBuilder) build() prefilter { + if !r.available || r.count > 3 { + return nil + } + var length int + bytes := [3]byte{} + + for b := 0; b <= 255; b++ { + if r.rareSet.contains(byte(b)) { + bytes[length] = byte(b) + length += 1 + } + } + + switch length { + case 0: + return nil + case 1: + return &rareBytesOne{ + byte1: bytes[0], + offset: r.byteOffsets.rbo[bytes[0]], + } + case 2: + return &rareBytesTwo{ + offsets: r.byteOffsets, + byte1: bytes[0], + byte2: bytes[1], + } + case 3: + return &rareBytesThree{ + offsets: r.byteOffsets, + byte1: bytes[0], + byte2: bytes[1], + byte3: bytes[2], + } + default: + return nil + } +} + +func (r *rareBytesBuilder) add(bytes []byte) { + if !r.available { + return + } + + if r.count > 3 { + r.available = false + return + } + + if len(bytes) >= 256 { + r.available = false + return + } + + if len(bytes) == 0 { + return + } + + rarest1, rarest2 := bytes[0], freqRank(bytes[0]) + found := false + + for pos, b := range bytes { + r.setOffset(pos, b) + if found { + continue + } + if r.rareSet.contains(b) { + found = true + } + rank := freqRank(b) + if rank < rarest2 { + rarest1 = b + rarest2 = rank + } + + if !found { + r.addRareByte(rarest1) + } + } +} + +func (r *rareBytesBuilder) addRareByte(b byte) { + r.addOneRareByte(b) + if r.asciiCaseInsensitive { + r.addOneRareByte(oppositeAsciiCase(b)) + } +} + +func (r *rareBytesBuilder) addOneRareByte(b byte) { + if r.rareSet.insert(b) { + r.count += 1 + r.rankSum += uint16(freqRank(b)) + } +} + +func newRareByteOffset(i int) rareByteOffset { + if i > math.MaxUint8 { + return rareByteOffset{max: 0} + } + b := byte(i) + return rareByteOffset{max: b} +} + +func (r *rareBytesBuilder) setOffset(pos int, b byte) { + offset := newRareByteOffset(pos) + r.byteOffsets.set(b, offset) + + if r.asciiCaseInsensitive { + r.byteOffsets.set(oppositeAsciiCase(b), offset) + } +} + +func newRareBytesBuilder(asciiCaseInsensitive bool) rareBytesBuilder { + return rareBytesBuilder{ + asciiCaseInsensitive: asciiCaseInsensitive, + rareSet: byteSet{}, + byteOffsets: rareByteOffsets{}, + available: true, + count: 0, + rankSum: 0, + } +} + +type startBytesBuilder struct { + asciiCaseInsensitive bool + byteSet []bool + count int + rankSum uint16 +} + +func (s *startBytesBuilder) build() prefilter { + if s.count > 3 { + return nil + } + var length int + bytes := [3]byte{} + + for b := 0; b < 256; b++ { + //todo case insensitive is not set in byteSet + if !s.byteSet[b] { + continue + } + if b > 0x7F { + return nil + } + bytes[length] = byte(b) + length += 1 + } + + switch length { + case 0: + return nil + case 1: + return &startBytesOne{byte1: bytes[0]} + case 2: + return &startBytesTwo{ + byte1: bytes[0], + byte2: bytes[1], + } + case 3: + return &startBytesThree{ + byte1: bytes[0], + byte2: bytes[1], + byte3: bytes[2], + } + default: + return nil + } +} + +func (s *startBytesBuilder) add(bytes []byte) { + if s.count > 3 || len(bytes) == 0 { + return + } + + b := bytes[0] + + s.addOneByte(b) + if s.asciiCaseInsensitive { + s.addOneByte(oppositeAsciiCase(b)) + } +} + +func (s *startBytesBuilder) addOneByte(b byte) { + if !s.byteSet[int(b)] { + s.byteSet[int(b)] = true + s.count += 1 + s.rankSum += uint16(freqRank(b)) + } +} + +func freqRank(b byte) byte { + return byteFrequencies[int(b)] +} + +func newStartBytesBuilder(asciiCaseInsensitive bool) startBytesBuilder { + return startBytesBuilder{ + asciiCaseInsensitive: asciiCaseInsensitive, + byteSet: make([]bool, 256), + count: 0, + rankSum: 0, + } +} + +const minSkips int = 40 +const minAvgFactor int = 2 + +type prefilterState struct { + skips int + skipped int + maxMatchLen int + inert bool + lastScanAt int +} + +func (p *prefilterState) updateAt(at int) { + if at > p.lastScanAt { + p.lastScanAt = at + } +} + +func (p *prefilterState) IsEffective(at int) bool { + if p.inert || at < p.lastScanAt { + return false + } + + if p.skips < minSkips { + return true + } + + minAvg := minAvgFactor * p.maxMatchLen + + if p.skipped >= minAvg*p.skips { + return true + } + + p.inert = true + return false +} + +func (p *prefilterState) updateSkippedBytes(skipped int) { + p.skips += 1 + p.skipped += skipped +} + +type candidateType uint + +const ( + noneCandidate candidateType = iota + matchCandidate + possibleStartOfMatchCandidate +) + +type prefilter interface { + NextCandidate(state *prefilterState, haystack []byte, at int) (interface{}, candidateType) + HeapBytes() int + ReportsFalsePositives() bool + LooksForNonStartOfMatch() bool + clone() prefilter +} + +func nextPrefilter(state *prefilterState, prefilter prefilter, haystack []byte, at int) (interface{}, candidateType) { + candidate, typ := prefilter.NextCandidate(state, haystack, at) + + switch typ { + case noneCandidate: + state.updateSkippedBytes(len(haystack) - at) + case matchCandidate: + m := candidate.(*Match) + state.updateSkippedBytes(m.Start() - at) + case possibleStartOfMatchCandidate: + i := candidate.(int) + state.updateSkippedBytes(i - at) + } + return candidate, typ +} diff --git a/nstr/ac/util.go b/nstr/ac/util.go new file mode 100644 index 0000000..def96e9 --- /dev/null +++ b/nstr/ac/util.go @@ -0,0 +1,9 @@ +package ac + +// max return max value +func max(a, b int) int { + if a > b { + return a + } + return b +} diff --git a/nstr/check.go b/nstr/check.go new file mode 100644 index 0000000..5c92e06 --- /dev/null +++ b/nstr/check.go @@ -0,0 +1,406 @@ +package nstr + +import ( + "encoding/json" + "net" + "net/url" + "reflect" + "regexp" + "strconv" + "strings" + "unicode" +) + +var ( + alphaMatcher = regexp.MustCompile(`^[a-zA-Z]+$`) + letterRegexMatcher = regexp.MustCompile(`[a-zA-Z]`) + intStrMatcher = regexp.MustCompile(`^[\+-]?\d+$`) + urlMatcher = regexp.MustCompile(`^((ftp|http|https?):\/\/)?(\S+(:\S*)?@)?((([1-9]\d?|1\d\d|2[01]\d|22[0-3])(\.(1?\d{1,2}|2[0-4]\d|25[0-5])){2}(?:\.([0-9]\d?|1\d\d|2[0-4]\d|25[0-4]))|(([a-zA-Z0-9]+([-\.][a-zA-Z0-9]+)*)|((www\.)?))?(([a-z\x{00a1}-\x{ffff}0-9]+-?-?)*[a-z\x{00a1}-\x{ffff}0-9]+)(?:\.([a-z\x{00a1}-\x{ffff}]{2,}))?))(:(\d{1,5}))?((\/|\?|#)[^\s]*)?$`) + dnsMatcher = regexp.MustCompile(`^[a-zA-Z]([a-zA-Z0-9\-]+[\.]?)*[a-zA-Z0-9]$`) + emailMatcher = regexp.MustCompile(`\w+([-+.]\w+)*@\w+([-.]\w+)*\.\w+([-.]\w+)*`) + chineseMobileMatcher = regexp.MustCompile(`^1(?:3\d|4[4-9]|5[0-35-9]|6[67]|7[013-8]|8\d|9\d)\d{8}$`) + chineseIdMatcher = regexp.MustCompile(`^[1-9]\d{5}(18|19|20|21|22)\d{2}((0[1-9])|(1[0-2]))(([0-2][1-9])|10|20|30|31)\d{3}[0-9Xx]$`) + chineseMatcher = regexp.MustCompile("[\u4e00-\u9fa5]") + chinesePhoneMatcher = regexp.MustCompile(`\d{3}-\d{8}|\d{4}-\d{7}|\d{4}-\d{8}`) + creditCardMatcher = regexp.MustCompile(`^(?:4[0-9]{12}(?:[0-9]{3})?|5[1-5][0-9]{14}|(222[1-9]|22[3-9][0-9]|2[3-6][0-9]{2}|27[01][0-9]|2720)[0-9]{12}|6(?:011|5[0-9][0-9])[0-9]{12}|3[47][0-9]{13}|3(?:0[0-5]|[68][0-9])[0-9]{11}|(?:2131|1800|35\\d{3})\\d{11}|6[27][0-9]{14})$`) + base64Matcher = regexp.MustCompile(`^(?:[A-Za-z0-9+\\/]{4})*(?:[A-Za-z0-9+\\/]{2}==|[A-Za-z0-9+\\/]{3}=|[A-Za-z0-9+\\/]{4})$`) +) + +// IsString check if the value data type is string or not. +func IsString(v any) bool { + if v == nil { + return false + } + switch v.(type) { + case string: + return true + default: + return false + } +} + +// IsAlpha checks if the string contains only letters (a-zA-Z). +func IsAlpha(str string) bool { + return alphaMatcher.MatchString(str) +} + +// IsAllUpper check if the string is all upper case letters A-Z. +func IsAllUpper(str string) bool { + for _, r := range str { + if !unicode.IsUpper(r) { + return false + } + } + return str != "" +} + +// IsAllLower check if the string is all lower case letters a-z. +func IsAllLower(str string) bool { + for _, r := range str { + if !unicode.IsLower(r) { + return false + } + } + return str != "" +} + +// ContainUpper check if the string contain at least one upper case letter A-Z. +func ContainUpper(str string) bool { + for _, r := range str { + if unicode.IsUpper(r) && unicode.IsLetter(r) { + return true + } + } + return false +} + +// ContainLower check if the string contain at least one lower case letter a-z. +func ContainLower(str string) bool { + for _, r := range str { + if unicode.IsLower(r) && unicode.IsLetter(r) { + return true + } + } + return false +} + +// ContainLetter check if the string contain at least one letter. +func ContainLetter(str string) bool { + return letterRegexMatcher.MatchString(str) +} + +// IsJSON checks if the string is valid JSON. +func IsJSON(str string) bool { + return json.Valid([]byte(str)) +} + +// IsNumberStr check if the string can convert to a number. +func IsNumberStr(s string) bool { + return IsIntStr(s) || IsFloatStr(s) +} + +// IsFloatStr check if the string can convert to a float. +func IsFloatStr(str string) bool { + _, e := strconv.ParseFloat(str, 64) + return e == nil +} + +// IsIntStr check if the string can convert to a integer. +func IsIntStr(str string) bool { + return intStrMatcher.MatchString(str) +} + +// IsIp check if the string is an ip address. +func IsIp(ipstr string) bool { + ip := net.ParseIP(ipstr) + return ip != nil +} + +// IsIpV4 check if the string is a ipv4 address. +func IsIpV4(ipstr string) bool { + ip := net.ParseIP(ipstr) + if ip == nil { + return false + } + return strings.Contains(ipstr, ".") +} + +// IsIpV6 check if the string is a ipv6 address. +func IsIpV6(ipstr string) bool { + ip := net.ParseIP(ipstr) + if ip == nil { + return false + } + return strings.Contains(ipstr, ":") +} + +// IsPort check if the string is a valid net port. +func IsPort(str string) bool { + if i, err := strconv.ParseInt(str, 10, 64); err == nil && i > 0 && i < 65536 { + return true + } + return false +} + +// IsUrl check if the string is url. +func IsUrl(str string) bool { + if str == "" || len(str) >= 2083 || len(str) <= 3 || strings.HasPrefix(str, ".") { + return false + } + u, err := url.Parse(str) + if err != nil { + return false + } + if strings.HasPrefix(u.Host, ".") { + return false + } + if u.Host == "" && (u.Path != "" && !strings.Contains(u.Path, ".")) { + return false + } + + return urlMatcher.MatchString(str) +} + +// IsDns check if the string is dns. +func IsDns(dns string) bool { + return dnsMatcher.MatchString(dns) +} + +// IsEmail check if the string is a email address. +func IsEmail(email string) bool { + return emailMatcher.MatchString(email) +} + +// IsChineseMobile check if the string is chinese mobile number. +func IsChineseMobile(mobileNum string) bool { + return chineseMobileMatcher.MatchString(mobileNum) +} + +// IsChineseIdNum check if the string is chinese id card. +func IsChineseIdNum(id string) bool { + return chineseIdMatcher.MatchString(id) +} + +// ContainChinese check if the string contain mandarin chinese. +func ContainChinese(s string) bool { + return chineseMatcher.MatchString(s) +} + +// IsChinesePhone check if the string is chinese phone number. +// Valid chinese phone is xxx-xxxxxxxx or xxxx-xxxxxxx. +func IsChinesePhone(phone string) bool { + return chinesePhoneMatcher.MatchString(phone) +} + +// IsCreditCard check if the string is credit card. +func IsCreditCard(creditCart string) bool { + return creditCardMatcher.MatchString(creditCart) +} + +// IsBase64 check if the string is base64 string. +func IsBase64(base64 string) bool { + return base64Matcher.MatchString(base64) +} + +// IsEmptyString check if the string is empty. +func IsEmptyString(str string) bool { + return len(str) == 0 +} + +// IsRegexMatch check if the string match the regexp. +func IsRegexMatch(str, regex string) bool { + reg := regexp.MustCompile(regex) + return reg.MatchString(str) +} + +// IsStrongPassword check if the string is strong password, if len(password) is less than the length param, return false +// Strong password: alpha(lower+upper) + number + special chars(!@#$%^&*()?><). +func IsStrongPassword(password string, length int) bool { + if len(password) < length { + return false + } + var num, lower, upper, special bool + for _, r := range password { + switch { + case unicode.IsDigit(r): + num = true + case unicode.IsUpper(r): + upper = true + case unicode.IsLower(r): + lower = true + case unicode.IsSymbol(r), unicode.IsPunct(r): + special = true + } + } + + return num && lower && upper && special +} + +// IsWeakPassword check if the string is weak password +// only letter or only number or letter + number. +func IsWeakPassword(password string) bool { + var num, letter, special bool + for _, r := range password { + switch { + case unicode.IsDigit(r): + num = true + case unicode.IsLetter(r): + letter = true + case unicode.IsSymbol(r), unicode.IsPunct(r): + special = true + } + } + + return (num || letter) && !special +} + +// IsZeroValue checks if value is a zero value. +func IsZeroValue(value any) bool { + if value == nil { + return true + } + + rv := reflect.ValueOf(value) + if rv.Kind() == reflect.Ptr { + rv = rv.Elem() + } + + if !rv.IsValid() { + return true + } + + switch rv.Kind() { + case reflect.String: + return rv.Len() == 0 + case reflect.Bool: + return !rv.Bool() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return rv.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return rv.Uint() == 0 + case reflect.Float32, reflect.Float64: + return rv.Float() == 0 + case reflect.Ptr, reflect.Chan, reflect.Func, reflect.Interface, reflect.Slice, reflect.Map: + return rv.IsNil() + } + + return reflect.DeepEqual(rv.Interface(), reflect.Zero(rv.Type()).Interface()) +} + +// ----- refer from github.com/yuin/goldmark/util + +// refer from github.com/yuin/goldmark/util +var spaceTable = [256]int8{0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + +// IsSpace returns true if the given character is a space, otherwise false. +func IsSpace(c byte) bool { return spaceTable[c] == 1 } + +// IsEmpty returns true if the given string is empty. +func IsEmpty(s string) bool { return len(s) == 0 } + +// IsBlank returns true if the given string is all space characters. +func IsBlank(s string) bool { return IsBlankBytes([]byte(s)) } + +// IsNotBlank returns true if the given string is not blank. +func IsNotBlank(s string) bool { return !IsBlankBytes([]byte(s)) } + +// IsBlankBytes returns true if the given []byte is all space characters. +func IsBlankBytes(bs []byte) bool { + for _, b := range bs { + if !IsSpace(b) { + return false + } + } + return true +} + +// IsSymbol reports whether the rune is a symbolic character. +func IsSymbol(r rune) bool { return unicode.IsSymbol(r) } + +// HasEmpty value for input strings +func HasEmpty(ss ...string) bool { + for _, s := range ss { + if s == "" { + return true + } + } + return false +} + +// IsAllEmpty for input strings +func IsAllEmpty(ss ...string) bool { + for _, s := range ss { + if s != "" { + return false + } + } + return true +} + +// ContainsByte in given string. +func ContainsByte(s string, c byte) bool { + return strings.IndexByte(s, c) >= 0 +} + +// ContainsOne substr(s) in the given string. alias of HasOneSub() +func ContainsOne(s string, subs []string) bool { return HasOneSub(s, subs) } + +// HasOneSub substr(s) in the given string. +func HasOneSub(s string, subs []string) bool { + for _, sub := range subs { + if strings.Contains(s, sub) { + return true + } + } + return false +} + +// ContainsAll substr(s) in the given string. alias of HasAllSubs() +func ContainsAll(s string, subs []string) bool { return HasAllSubs(s, subs) } + +// HasAllSubs all substr in the given string. +func HasAllSubs(s string, subs []string) bool { + for _, sub := range subs { + if !strings.Contains(s, sub) { + return false + } + } + return true +} + +// IsStartsOf alias of the HasOnePrefix +func IsStartsOf(s string, prefixes []string) bool { + return HasOnePrefix(s, prefixes) +} + +// HasOnePrefix the string start withs one of the subs +func HasOnePrefix(s string, prefixes []string) bool { + for _, prefix := range prefixes { + if strings.HasPrefix(s, prefix) { + return true + } + } + return false +} + +// HasPrefix substr in the given string. +func HasPrefix(s string, prefix string) bool { return strings.HasPrefix(s, prefix) } + +// IsStartOf alias of the strings.HasPrefix +func IsStartOf(s, prefix string) bool { return strings.HasPrefix(s, prefix) } + +// HasSuffix substr in the given string. +func HasSuffix(s string, suffix string) bool { return strings.HasSuffix(s, suffix) } + +// IsEndOf alias of the strings.HasSuffix +func IsEndOf(s, suffix string) bool { return strings.HasSuffix(s, suffix) } + +// HasOneSuffix the string end withs one of the subs +func HasOneSuffix(s string, suffixes []string) bool { + for _, suffix := range suffixes { + if strings.HasSuffix(s, suffix) { + return true + } + } + return false +} diff --git a/nstr/codec.go b/nstr/codec.go new file mode 100644 index 0000000..c020066 --- /dev/null +++ b/nstr/codec.go @@ -0,0 +1,71 @@ +package nstr + +import ( + "golang.org/x/text/encoding/simplifiedchinese" + "unicode/utf8" +) + +const ( + GBK string = "GBK" + UTF8 string = "UTF8" + UNKNOWN string = "UNKNOWN" +) + +// Charset 获取字符的编码类型 +// 需要说明的是,IsGBK()是通过双字节是否落在gbk的编码范围内实现的, +// 而utf-8编码格式的每个字节都是落在gbk的编码范围内, +// 所以只有先调用utf8.Valid() 先判断不是utf-8编码,再调用IsGBK()才有意义 +func Charset(data []byte) string { + if utf8.Valid(data) { + return UTF8 + } else if IsGBK(data) { + return GBK + } else { + return UNKNOWN + } +} + +// IsGBK 判断字符是否是 GBK 编码 +// 需要说明的是,IsGBK()是通过双字节是否落在gbk的编码范围内实现的, +// 而utf-8编码格式的每个字节都是落在gbk的编码范围内, +// 所以只有先调用utf8.Valid() 先判断不是utf-8编码,再调用IsGBK()才有意义 +// +// usage +// data := []byte("你好") +// if utf8.Valid(data) { +// fmt.Println("data encoding is utf-8") +// }else if(IsGBK(data)) { +// fmt.Println("data encoding is GBK") +// } +func IsGBK(data []byte) bool { + length := len(data) + i := 0 + for i < length { + if data[i] <= 0x7f { + // 编码0~127,只有一个字节的编码,兼容ASCII码 + i++ + continue + } else { + //大于127的使用双字节编码,落在gbk编码范围内的字符 + if data[i] >= 0x81 && + data[i] <= 0xfe && + data[i+1] >= 0x40 && + data[i+1] <= 0xfe && + data[i+1] != 0xf7 { + i += 2 + continue + } else { + return false + } + } + } + return true +} + +func ToGBK(data []byte) ([]byte, error) { + transBytes, err := simplifiedchinese.GB18030.NewDecoder().Bytes(data) + if err != nil { + return data, err + } + return transBytes, nil +} diff --git a/nstr/convert.go b/nstr/convert.go new file mode 100644 index 0000000..bdf2e3d --- /dev/null +++ b/nstr/convert.go @@ -0,0 +1,311 @@ +package nstr + +import ( + "fmt" + "git.noahlan.cn/noahlan/ntool/internal/convert" + "git.noahlan.cn/noahlan/ntool/ndef" + "git.noahlan.cn/noahlan/ntool/nmath" + "reflect" + "strconv" + "strings" + "time" + "unsafe" +) + +// Quote alias of strings.Quote +func Quote(s string) string { return strconv.Quote(s) } + +// Unquote remove start and end quotes by single-quote or double-quote +// +// tip: strconv.Unquote cannot unquote single-quote +func Unquote(s string) string { + ln := len(s) + if ln < 2 { + return s + } + + qs, qe := s[0], s[ln-1] + + var valid bool + if qs == '"' && qe == '"' { + valid = true + } else if qs == '\'' && qe == '\'' { + valid = true + } + + if valid { + s = s[1 : ln-1] // exclude quotes + } + // strconv.Unquote cannot unquote single-quote + // if ns, err := strconv.Unquote(s); err == nil { + // return ns + // } + return s +} + +// JoinAny type to string +func JoinAny(sep string, parts ...any) string { + ss := make([]string, 0, len(parts)) + for _, part := range parts { + ss = append(ss, SafeString(part)) + } + + return strings.Join(ss, sep) +} + +/************************************************************* + * convert value to string + *************************************************************/ + +// ToString convert value to string, return error on failed +func ToString(val any) (string, error) { + return AnyToString(val, true) +} + +// SafeString convert value to string, will ignore error +func SafeString(in any) string { + val, _ := AnyToString(in, false) + return val +} + +// MustString convert value to string, will panic on error +func MustString(in any) string { + val, err := AnyToString(in, false) + if err != nil { + panic(err) + } + return val +} + +// AnyToString convert value to string. +// +// For defaultAsErr: +// +// - False will use fmt.Sprint convert complex type +// - True will return error on fail. +func AnyToString(val any, defaultAsErr bool) (str string, err error) { + if val == nil { + return + } + + switch value := val.(type) { + case int: + str = strconv.Itoa(value) + case int8: + str = strconv.Itoa(int(value)) + case int16: + str = strconv.Itoa(int(value)) + case int32: // same as `rune` + str = strconv.Itoa(int(value)) + case int64: + str = strconv.FormatInt(value, 10) + case uint: + str = strconv.FormatUint(uint64(value), 10) + case uint8: + str = strconv.FormatUint(uint64(value), 10) + case uint16: + str = strconv.FormatUint(uint64(value), 10) + case uint32: + str = strconv.FormatUint(uint64(value), 10) + case uint64: + str = strconv.FormatUint(value, 10) + case float32: + str = strconv.FormatFloat(float64(value), 'f', -1, 32) + case float64: + str = strconv.FormatFloat(value, 'f', -1, 64) + case bool: + str = strconv.FormatBool(value) + case string: + str = value + case []byte: + str = string(value) + case time.Duration: + str = strconv.FormatInt(int64(value), 10) + case fmt.Stringer: + str = value.String() + default: + if defaultAsErr { + err = ndef.ErrConvType + } else { + str = fmt.Sprint(value) + } + } + return +} + +/************************************************************* + * convert string value to byte + * refer from https://github.com/valyala/fastjson/blob/master/util.go + *************************************************************/ + +// Byte2str convert bytes to string +func Byte2str(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +// Byte2string convert bytes to string +func Byte2string(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +// ToBytes convert string to bytes +func ToBytes(s string) (b []byte) { + strh := (*reflect.StringHeader)(unsafe.Pointer(&s)) + + sh := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + sh.Data = strh.Data + sh.Len = strh.Len + sh.Cap = strh.Len + return b +} + +/************************************************************* + * convert string value to bool + *************************************************************/ + +// ToBool convert string to bool +func ToBool(s string) (bool, error) { + return Bool(s) +} + +// QuietBool convert to bool, will ignore error +func QuietBool(s string) bool { + val, _ := Bool(strings.TrimSpace(s)) + return val +} + +// MustBool convert, will panic on error +func MustBool(s string) bool { + val, err := Bool(strings.TrimSpace(s)) + if err != nil { + panic(err) + } + return val +} + +// Bool parse string to bool. like strconv.ParseBool() +func Bool(s string) (bool, error) { + return convert.StrToBool(s) +} + +/************************************************************* + * convert string value to int, float + *************************************************************/ + +// Int convert string to int, alias of ToInt() +func Int(s string) (int, error) { + return strconv.Atoi(strings.TrimSpace(s)) +} + +// ToInt convert string to int, return error on fail +func ToInt(s string) (int, error) { + return strconv.Atoi(strings.TrimSpace(s)) +} + +// Int2 convert string to int, will ignore error +func Int2(s string) int { + val, _ := ToInt(s) + return val +} + +// QuietInt convert string to int, will ignore error +func QuietInt(s string) int { + val, _ := ToInt(s) + return val +} + +// MustInt convert string to int, will panic on error +func MustInt(s string) int { + return IntOrPanic(s) +} + +// IntOrPanic convert value to int, will panic on error +func IntOrPanic(s string) int { + val, err := ToInt(s) + if err != nil { + panic(err) + } + return val +} + +// Int64 convert string to int, will ignore error +func Int64(s string) int64 { + val, _ := Int64OrErr(s) + return val +} + +// QuietInt64 convert string to int, will ignore error +func QuietInt64(s string) int64 { + val, _ := Int64OrErr(s) + return val +} + +// ToInt64 convert string to int, return error on fail +func ToInt64(s string) (int64, error) { + return strconv.ParseInt(s, 10, 0) +} + +// Int64OrErr convert string to int, return error on fail +func Int64OrErr(s string) (int64, error) { + return strconv.ParseInt(s, 10, 0) +} + +// MustInt64 convert value to int, will panic on error +func MustInt64(s string) int64 { + return Int64OrPanic(s) +} + +// Int64OrPanic convert value to int, will panic on error +func Int64OrPanic(s string) int64 { + val, err := strconv.ParseInt(s, 10, 0) + if err != nil { + panic(err) + } + return val +} + +/************************************************************* + * convert string value to int/string slice, time.Time + *************************************************************/ + +// Ints alias of the ToIntSlice(). default sep is comma(,) +func Ints(s string, sep ...string) []int { + ints, _ := ToIntSlice(s, sep...) + return ints +} + +// ToInts alias of the ToIntSlice(). default sep is comma(,) +func ToInts(s string, sep ...string) ([]int, error) { return ToIntSlice(s, sep...) } + +// ToIntSlice split string to slice and convert item to int. +// +// Default sep is comma(,) +func ToIntSlice(s string, sep ...string) (ints []int, err error) { + ss := ToSlice(s, sep...) + for _, item := range ss { + iVal, err := nmath.ToInt(item) + if err != nil { + return []int{}, err + } + + ints = append(ints, iVal) + } + return +} + +// ToArray alias of the ToSlice() +func ToArray(s string, sep ...string) []string { return ToSlice(s, sep...) } + +// Strings alias of the ToSlice() +func Strings(s string, sep ...string) []string { return ToSlice(s, sep...) } + +// ToStrings alias of the ToSlice() +func ToStrings(s string, sep ...string) []string { return ToSlice(s, sep...) } + +// ToSlice split string to array. +func ToSlice(s string, sep ...string) []string { + if len(sep) > 0 { + return Split(s, sep[0]) + } + return Split(s, ",") +} diff --git a/nstr/filter.go b/nstr/filter.go new file mode 100644 index 0000000..1e72534 --- /dev/null +++ b/nstr/filter.go @@ -0,0 +1,57 @@ +package nstr + +import "strings" + +/************************************************************* + * String filtering + *************************************************************/ + +// Trim string. if cutSet is empty, will trim SPACE. +func Trim(s string, cutSet ...string) string { + if ln := len(cutSet); ln > 0 && cutSet[0] != "" { + if ln == 1 { + return strings.Trim(s, cutSet[0]) + } + + return strings.Trim(s, strings.Join(cutSet, "")) + } + + return strings.TrimSpace(s) +} + +// Ltrim alias of TrimLeft +func Ltrim(s string, cutSet ...string) string { return TrimLeft(s, cutSet...) } + +// LTrim alias of TrimLeft +func LTrim(s string, cutSet ...string) string { return TrimLeft(s, cutSet...) } + +// TrimLeft char in the string. if cutSet is empty, will trim SPACE. +func TrimLeft(s string, cutSet ...string) string { + if ln := len(cutSet); ln > 0 && cutSet[0] != "" { + if ln == 1 { + return strings.TrimLeft(s, cutSet[0]) + } + + return strings.TrimLeft(s, strings.Join(cutSet, "")) + } + + return strings.TrimLeft(s, " ") +} + +// Rtrim alias of TrimRight +func Rtrim(s string, cutSet ...string) string { return TrimRight(s, cutSet...) } + +// RTrim alias of TrimRight +func RTrim(s string, cutSet ...string) string { return TrimRight(s, cutSet...) } + +// TrimRight char in the string. if cutSet is empty, will trim SPACE. +func TrimRight(s string, cutSet ...string) string { + if ln := len(cutSet); ln > 0 && cutSet[0] != "" { + if ln == 1 { + return strings.TrimRight(s, cutSet[0]) + } + return strings.TrimRight(s, strings.Join(cutSet, "")) + } + + return strings.TrimRight(s, " ") +} diff --git a/nstr/match.go b/nstr/match.go new file mode 100644 index 0000000..2f775c9 --- /dev/null +++ b/nstr/match.go @@ -0,0 +1,131 @@ +package nstr + +import ( + "path" + "strings" +) + +// SimpleMatch all sub-string in the give text string. +// +// Difference the ContainsAll: +// +// - start with ^ for exclude contains check. +// - end with $ for check end with keyword. +func SimpleMatch(s string, keywords []string) bool { + for _, keyword := range keywords { + kln := len(keyword) + if kln == 0 { + continue + } + + // exclude + if kln > 1 && keyword[0] == '^' { + if strings.Contains(s, keyword[1:]) { + return false + } + continue + } + + // end with + if kln > 1 && keyword[kln-1] == '$' { + return strings.HasSuffix(s, keyword[:kln-1]) + } + + // include + if !strings.Contains(s, keyword) { + return false + } + } + return true +} + +// QuickMatch check for a string. pattern can be a sub string. +func QuickMatch(pattern, s string) bool { + if strings.ContainsRune(pattern, '*') { + return GlobMatch(pattern, s) + } + return strings.Contains(s, pattern) +} + +// PathMatch check for a string match the pattern. alias of the path.Match() +// +// TIP: `*` can match any char, not contain `/`. +func PathMatch(pattern, s string) bool { + ok, err := path.Match(pattern, s) + if err != nil { + ok = false + } + return ok +} + +// GlobMatch check for a string match the pattern. +// +// Difference with PathMatch() is: `*` can match any char, contain `/`. +func GlobMatch(pattern, s string) bool { + // replace `/` to `S` for path.Match + pattern = strings.Replace(pattern, "/", "S", -1) + s = strings.Replace(s, "/", "S", -1) + + ok, err := path.Match(pattern, s) + if err != nil { + ok = false + } + return ok +} + +// LikeMatch simple check for a string match the pattern. pattern like the SQL LIKE. +func LikeMatch(pattern, s string) bool { + ln := len(pattern) + if ln < 2 { + return false + } + + // eg `%abc` `%abc%` + if pattern[0] == '%' { + if ln > 2 && pattern[ln-1] == '%' { + return strings.Contains(s, pattern[1:ln-1]) + } else { + return strings.HasSuffix(s, pattern[1:]) + } + } + + // eg `abc%` + if pattern[ln-1] == '%' { + return strings.HasPrefix(s, pattern[:ln-1]) + } + return pattern == s +} + +// MatchNodePath check for a string match the pattern. +// +// Use on pattern: +// - `*` match any to sep +// - `**` match any to end. only allow at start or end on pattern. +// +// Example: +// +// strutil.MatchNodePath() +func MatchNodePath(pattern, s string, sep string) bool { + if pattern == "**" || pattern == s { + return true + } + if pattern == "" { + return len(s) == 0 + } + + if i := strings.Index(pattern, "**"); i >= 0 { + if i == 0 { // at start + return strings.HasSuffix(s, pattern[2:]) + } + return strings.HasPrefix(s, pattern[:len(pattern)-2]) + } + + pattern = strings.Replace(pattern, sep, "/", -1) + s = strings.Replace(s, sep, "/", -1) + + ok, err := path.Match(pattern, s) + if err != nil { + ok = false + } + return ok +} diff --git a/nstr/padding.go b/nstr/padding.go new file mode 100644 index 0000000..1972874 --- /dev/null +++ b/nstr/padding.go @@ -0,0 +1,130 @@ +package nstr + +import "fmt" + +// PosFlag type +type PosFlag uint8 + +// Position for padding/resize string +const ( + PosLeft PosFlag = iota + PosRight + PosMiddle +) + +/************************************************************* + * String padding operation + *************************************************************/ + +// Padding a string. +func Padding(s, pad string, length int, pos PosFlag) string { + diff := len(s) - length + if diff >= 0 { // do not need padding. + return s + } + + if pad == "" || pad == " " { + mark := "" + if pos == PosRight { // to right + mark = "-" + } + + // padding left: "%7s", padding right: "%-7s" + tpl := fmt.Sprintf("%s%d", mark, length) + return fmt.Sprintf(`%`+tpl+`s`, s) + } + + if pos == PosRight { // to right + return s + Repeat(pad, -diff) + } + return Repeat(pad, -diff) + s +} + +// PadLeft a string. +func PadLeft(s, pad string, length int) string { + return Padding(s, pad, length, PosLeft) +} + +// PadRight a string. +func PadRight(s, pad string, length int) string { + return Padding(s, pad, length, PosRight) +} + +// PadAround a string. Both left and right +func PadAround(s, pad string, length int) string { + return PadRight(PadLeft(s, pad, length-1), pad, length) +} + +// Resize a string by given length and align settings. padding space. +func Resize(s string, length int, align PosFlag) string { + diff := len(s) - length + if diff >= 0 { // do not need padding. + return s + } + + if align == PosMiddle { + strLn := len(s) + padLn := (length - strLn) / 2 + padStr := string(make([]byte, padLn)) + + if diff := length - padLn*2; diff > 0 { + s += " " + } + return padStr + s + padStr + } + + return Padding(s, " ", length, align) +} + +// PadChars padding a rune/byte to want length and with position flag +func PadChars[T byte | rune](cs []T, pad T, length int, pos PosFlag) []T { + ln := len(cs) + if ln >= length { + ns := make([]T, length) + copy(ns, cs[:length]) + return ns + } + + idx := length - ln + ns := make([]T, length) + ps := RepeatChars(pad, idx) + if pos == PosRight { + copy(ns, cs) + copy(ns[idx:], ps) + } else { // to left + copy(ns[:idx], ps) + copy(ns[idx:], cs) + } + + return ns +} + +// PadBytes padding a byte to want length and with position flag +func PadBytes(bs []byte, pad byte, length int, pos PosFlag) []byte { + return PadChars(bs, pad, length, pos) +} + +// PadBytesLeft a byte to want length +func PadBytesLeft(bs []byte, pad byte, length int) []byte { + return PadChars(bs, pad, length, PosLeft) +} + +// PadBytesRight a byte to want length +func PadBytesRight(bs []byte, pad byte, length int) []byte { + return PadChars(bs, pad, length, PosRight) +} + +// PadRunes padding a rune to want length and with position flag +func PadRunes(rs []rune, pad rune, length int, pos PosFlag) []rune { + return PadChars(rs, pad, length, pos) +} + +// PadRunesLeft a rune to want length +func PadRunesLeft(rs []rune, pad rune, length int) []rune { + return PadChars(rs, pad, length, PosLeft) +} + +// PadRunesRight a rune to want length +func PadRunesRight(rs []rune, pad rune, length int) []rune { + return PadChars(rs, pad, length, PosRight) +} diff --git a/nstr/parser.go b/nstr/parser.go new file mode 100644 index 0000000..c057b8f --- /dev/null +++ b/nstr/parser.go @@ -0,0 +1,140 @@ +package nstr + +import ( + "errors" + "git.noahlan.cn/noahlan/ntool/nbyte" + "strconv" + "strings" + "unicode" +) + +// ParseSizeOpt parse size expression options +type ParseSizeOpt struct { + // OneAsMax if only one size value, use it as max size. default is false + OneAsMax bool + // SepChar is the separator char for time range string. default is '~' + SepChar byte + // KeywordFn is the function for parse keyword time string. + KeywordFn func(string) (min, max uint64, err error) +} + +func ensureOpt(opt *ParseSizeOpt) *ParseSizeOpt { + if opt == nil { + opt = &ParseSizeOpt{SepChar: '~'} + } else { + if opt.SepChar == 0 { + opt.SepChar = '~' + } + } + return opt +} + +// ErrInvalidSizeExpr invalid size expression error +var ErrInvalidSizeExpr = errors.New("invalid size expr") + +// ParseSizeRange parse range size expression to min and max size. +// +// Expression format: +// +// "1KB~2MB" => 1KB to 2MB +// "-1KB" => <1KB +// "~1MB" => <1MB +// "< 1KB" => <1KB +// "1KB" => >1KB +// "1KB~" => >1KB +// ">1KB" => >1KB +// "+1KB" => >1KB +func ParseSizeRange(expr string, opt *ParseSizeOpt) (min, max uint64, err error) { + opt = ensureOpt(opt) + expr = strings.TrimSpace(expr) + if expr == "" { + err = ErrInvalidSizeExpr + return + } + + // parse size range. eg: "1KB~2MB" + if strings.IndexByte(expr, '~') > -1 { + s1, s2 := TrimCut(expr, "~") + if s1 != "" { + min, err = ToByteSize(s1) + if err != nil { + return + } + } + + if s2 != "" { + max, err = ToByteSize(s2) + } + return + } + + // parse single size. eg: "1KB" + if nbyte.IsDigit(expr[0]) { + min, err = ToByteSize(expr) + if err != nil { + return + } + if opt.OneAsMax { + max = min + } + return + } + + // parse with prefix. eg: "<1KB", ">= 1KB", "-1KB", "+1KB" + switch expr[0] { + case '<', '-': + max, err = ToByteSize(strings.Trim(expr[1:], " =")) + case '>', '+': + min, err = ToByteSize(strings.Trim(expr[1:], " =")) + default: + // parse keyword. eg: "small", "large" + if opt.KeywordFn != nil { + min, max, err = opt.KeywordFn(expr) + } else { + err = ErrInvalidSizeExpr + } + } + return +} + +// ToByteSize converts size string like 1GB/1g or 12mb/12M into an unsigned integer number of bytes +func ToByteSize(sizeStr string) (uint64, error) { + sizeStr = strings.TrimSpace(sizeStr) + lastPos := len(sizeStr) - 1 + if lastPos < 0 { + return 0, nil + } + + if sizeStr[lastPos] == 'b' || sizeStr[lastPos] == 'B' { + // last second char is k,m,g,t + lastSec := sizeStr[lastPos-1] + if lastSec > 'A' { + lastPos-- + } + } else if nbyte.IsDigit(sizeStr[lastPos]) { // not unit suffix. eg: 346 + return strconv.ParseUint(sizeStr, 10, 32) + } + + multiplier := float64(1) + switch unicode.ToLower(rune(sizeStr[lastPos])) { + case 'k': + multiplier = 1 << 10 + case 'm': + multiplier = 1 << 20 + case 'g': + multiplier = 1 << 30 + case 't': + multiplier = 1 << 40 + case 'p': + multiplier = 1 << 50 + default: // b + multiplier = 1 + } + + sizeNum := strings.TrimSpace(sizeStr[:lastPos]) + size, err := strconv.ParseFloat(sizeNum, 64) + if err != nil { + return 0, err + } + return uint64(size * multiplier), nil +} diff --git a/nstr/repeat.go b/nstr/repeat.go new file mode 100644 index 0000000..0de206b --- /dev/null +++ b/nstr/repeat.go @@ -0,0 +1,47 @@ +package nstr + +import "strings" + +/************************************************************* + * String repeat operation + *************************************************************/ + +// Repeat a string +func Repeat(s string, times int) string { + if times <= 0 { + return "" + } + if times == 1 { + return s + } + + ss := make([]string, 0, times) + for i := 0; i < times; i++ { + ss = append(ss, s) + } + + return strings.Join(ss, "") +} + +// RepeatRune repeat a rune char. +func RepeatRune(char rune, times int) []rune { + return RepeatChars(char, times) +} + +// RepeatBytes repeat a byte char. +func RepeatBytes(char byte, times int) []byte { + return RepeatChars(char, times) +} + +// RepeatChars repeat a byte char. +func RepeatChars[T byte | rune](char T, times int) []T { + if times <= 0 { + return make([]T, 0) + } + + chars := make([]T, 0, times) + for i := 0; i < times; i++ { + chars = append(chars, char) + } + return chars +} diff --git a/nstr/runes.go b/nstr/runes.go new file mode 100644 index 0000000..fa0d46d --- /dev/null +++ b/nstr/runes.go @@ -0,0 +1,78 @@ +package nstr + +import "unicode" + +// RuneIsLower checks if a character is lower case ('a' to 'z') +func RuneIsLower(c rune) bool { + return 'a' <= c && c <= 'z' +} + +// RuneToLower converts a character 'A' to 'Z' to its lower case +func RuneToLower(r rune) rune { + if r >= 'A' && r <= 'Z' { + return r + 32 + } + return r +} + +// RuneToLowerAll converts a character 'A' to 'Z' to its lower case +func RuneToLowerAll(rs []rune) []rune { + for i := range rs { + rs[i] = RuneToLower(rs[i]) + } + return rs +} + +// RuneIsUpper checks if a character is upper case ('A' to 'Z') +func RuneIsUpper(c rune) bool { + return 'A' <= c && c <= 'Z' +} + +// RuneToUpper converts a character 'a' to 'z' to its upper case +func RuneToUpper(r rune) rune { + if r >= 'a' && r <= 'z' { + return r - 32 + } + return r +} + +// RuneToUpperAll converts a character 'a' to 'z' to its upper case +func RuneToUpperAll(rs []rune) []rune { + for i := range rs { + rs[i] = RuneToUpper(rs[i]) + } + return rs +} + +// RuneIsDigit checks if a character is digit ('0' to '9') +func RuneIsDigit(r rune) bool { + return r >= '0' && r <= '9' +} + +// RuneIsLetter checks r is a letter but not CJK character. +func RuneIsLetter(r rune) bool { + if !unicode.IsLetter(r) { + return false + } + + switch { + // cjk char: /[\u3040-\u30ff\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff\uff66-\uff9f]/ + // hiragana and katakana (Japanese only) + case r >= '\u3034' && r < '\u30ff': + return false + // CJK unified ideographs extension A (Chinese, Japanese, and Korean) + case r >= '\u3400' && r < '\u4dbf': + return false + // CJK unified ideographs (Chinese, Japanese, and Korean) + case r >= '\u4e00' && r < '\u9fff': + return false + // CJK compatibility ideographs (Chinese, Japanese, and Korean) + case r >= '\uf900' && r < '\ufaff': + return false + // half-width katakana (Japanese only) + case r >= '\uff66' && r < '\uff9f': + return false + } + + return true +} diff --git a/nstr/split.go b/nstr/split.go new file mode 100644 index 0000000..950dede --- /dev/null +++ b/nstr/split.go @@ -0,0 +1,159 @@ +package nstr + +import "strings" + +/************************************************************* + * String split operation + *************************************************************/ + +// Cut same of the strings.Cut +func Cut(s, sep string) (before string, after string, found bool) { + if i := strings.Index(s, sep); i >= 0 { + return s[:i], s[i+len(sep):], true + } + return s, "", false +} + +// QuietCut always returns two substring. +func QuietCut(s, sep string) (before string, after string) { + before, after, _ = Cut(s, sep) + return +} + +// MustCut always returns two substring. +func MustCut(s, sep string) (before string, after string) { + var ok bool + before, after, ok = Cut(s, sep) + if !ok { + panic("cannot split input string to two nodes") + } + return +} + +// TrimCut always returns two substring and trim space for items. +func TrimCut(s, sep string) (string, string) { + before, after, _ := Cut(s, sep) + return strings.TrimSpace(before), strings.TrimSpace(after) +} + +// SplitKV split string to key and value. +func SplitKV(s, sep string) (string, string) { return TrimCut(s, sep) } + +// SplitValid string to slice. will trim each item and filter empty string node. +func SplitValid(s, sep string) (ss []string) { return Split(s, sep) } + +// Split string to slice. will trim each item and filter empty string node. +func Split(s, sep string) (ss []string) { + if s = strings.TrimSpace(s); s == "" { + return + } + + for _, val := range strings.Split(s, sep) { + if val = strings.TrimSpace(val); val != "" { + ss = append(ss, val) + } + } + return +} + +// SplitNValid string to slice. will filter empty string node. +func SplitNValid(s, sep string, n int) (ss []string) { return SplitN(s, sep, n) } + +// SplitN string to slice. will filter empty string node. +func SplitN(s, sep string, n int) (ss []string) { + if s = strings.TrimSpace(s); s == "" { + return + } + + rawList := strings.Split(s, sep) + for i, val := range rawList { + if val = strings.TrimSpace(val); val != "" { + if len(ss) == n-1 { + ss = append(ss, strings.TrimSpace(strings.Join(rawList[i:], sep))) + break + } + + ss = append(ss, val) + } + } + return +} + +// SplitTrimmed split string to slice. +// will trim space for each node, but not filter empty +func SplitTrimmed(s, sep string) (ss []string) { + if s = strings.TrimSpace(s); s == "" { + return + } + + for _, val := range strings.Split(s, sep) { + ss = append(ss, strings.TrimSpace(val)) + } + return +} + +// SplitNTrimmed split string to slice. +// will trim space for each node, but not filter empty +func SplitNTrimmed(s, sep string, n int) (ss []string) { + if s = strings.TrimSpace(s); s == "" { + return + } + + for _, val := range strings.SplitN(s, sep, n) { + ss = append(ss, strings.TrimSpace(val)) + } + return +} + +// Substring returns a substring of the specified length starting at the specified offset position. +// if length <= 0, return pos to end. +func Substring(s string, pos, length int) string { + runes := []rune(s) + strLn := len(runes) + + // pos is too large + if pos >= strLn { + return "" + } + + stopIdx := pos + length + if length == 0 || stopIdx > strLn { + stopIdx = strLn + } else if length < 0 { + stopIdx = strLn + length + } + + return string(runes[pos:stopIdx]) +} + +// SplitInlineComment for an inline text string. +func SplitInlineComment(val string, strict ...bool) (string, string) { + // strict check: must with space + if len(strict) > 0 && strict[0] { + if pos := strings.Index(val, " #"); pos > -1 { + return strings.TrimRight(val[0:pos], " "), val[pos+1:] + } + + if pos := strings.Index(val, " //"); pos > -1 { + return strings.TrimRight(val[0:pos], " "), val[pos+1:] + } + } else { + if pos := strings.IndexByte(val, '#'); pos > -1 { + return strings.TrimRight(val[0:pos], " "), val[pos:] + } + + if pos := strings.Index(val, "//"); pos > -1 { + return strings.TrimRight(val[0:pos], " "), val[pos:] + } + } + + return val, "" +} + +// FirstLine from command output +func FirstLine(output string) string { + if i := strings.IndexByte(output, '\n'); i >= 0 { + return output[0:i] + } + return output +} diff --git a/nstr/textutil/textutil.go b/nstr/textutil/textutil.go new file mode 100644 index 0000000..44f7a01 --- /dev/null +++ b/nstr/textutil/textutil.go @@ -0,0 +1,63 @@ +// Package textutil provide some extra text handle util +package textutil + +import ( + "fmt" + "git.noahlan.cn/noahlan/ntool/narr" + "git.noahlan.cn/noahlan/ntool/nmap" + "git.noahlan.cn/noahlan/ntool/nstr" + "strings" +) + +// ReplaceVars by regex replace given tpl vars. +// +// If format is empty, will use {const defaultVarFormat} +func ReplaceVars(text string, vars map[string]any, format string) string { + return NewVarReplacer(format).Replace(text, vars) +} + +// RenderSMap by regex replace given tpl vars. +// +// If format is empty, will use {const defaultVarFormat} +func RenderSMap(text string, vars map[string]string, format string) string { + return NewVarReplacer(format).RenderSimple(text, vars) +} + +// IsMatchAll keywords in the give text string. +// +// TIP: can use ^ for exclude match. +func IsMatchAll(s string, keywords []string) bool { + return nstr.SimpleMatch(s, keywords) +} + +// ParseInlineINI parse config string to string-map. it's like INI format contents. +// +// Examples: +// +// eg: "name=val0;shorts=i;required=true;desc=a message" +// => +// {name: val0, shorts: i, required: true, desc: a message} +func ParseInlineINI(tagVal string, keys ...string) (mp nmap.SMap, err error) { + ss := nstr.Split(tagVal, ";") + ln := len(ss) + if ln == 0 { + return + } + + mp = make(nmap.SMap, ln) + for _, s := range ss { + if !strings.ContainsRune(s, '=') { + err = fmt.Errorf("parse inline config error: must match `KEY=VAL`") + return + } + + key, val := nstr.TrimCut(s, "=") + if len(keys) > 0 && !narr.StringsHas(keys, key) { + err = fmt.Errorf("parse inline config error: invalid key name %q", key) + return + } + + mp[key] = val + } + return +} diff --git a/nstr/textutil/textutil_test.go b/nstr/textutil/textutil_test.go new file mode 100644 index 0000000..ed9603d --- /dev/null +++ b/nstr/textutil/textutil_test.go @@ -0,0 +1,147 @@ +package textutil_test + +import ( + "git.noahlan.cn/noahlan/ntool/nstr" + "git.noahlan.cn/noahlan/ntool/nstr/textutil" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestReplaceVars(t *testing.T) { + tplVars := map[string]any{ + "name": "inhere", + "key_01": "inhere", + "key-02": "inhere", + "info": map[string]any{"age": 230, "sex": "man"}, + } + + tests := []struct { + tplText string + want string + }{ + {"hi inhere", "hi inhere"}, + {"hi {{name}}", "hi inhere"}, + {"hi {{ name}}", "hi inhere"}, + {"hi {{name }}", "hi inhere"}, + {"hi {{ name }}", "hi inhere"}, + {"hi {{ key_01 }}", "hi inhere"}, + {"hi {{ key-02 }}", "hi inhere"}, + {"hi {{ info.age }}", "hi 230"}, + } + + for i, tt := range tests { + t.Run(nstr.JoinAny(" ", "case", i), func(t *testing.T) { + if got := textutil.ReplaceVars(tt.tplText, tplVars, ""); got != tt.want { + t.Errorf("ReplaceVars() = %v, want = %v", got, tt.want) + } + }) + } + + // custom format + assert.Equal(t, "hi inhere", textutil.ReplaceVars("hi {$name}", tplVars, "{$,}")) + assert.Equal(t, "hi inhere age is 230", textutil.ReplaceVars("hi $name age is $info.age", tplVars, "$,")) + assert.Equal(t, "hi {$name}", textutil.ReplaceVars("hi {$name}", nil, "{$,}")) +} + +func TestNewFullReplacer(t *testing.T) { + vp := textutil.NewFullReplacer("") + + tplVars := map[string]any{ + "name": "inhere", + "info": map[string]any{"age": 230, "sex": "man"}, + } + + tpl := "hi, {{ name }}, {{ age | 23 }}" + str := vp.Render(tpl, nil) + assert.Eq(t, "hi, {{ name }}, 23", str) + + str = vp.Render(tpl, tplVars) + assert.Eq(t, "hi, inhere, 23", str) + + vp.OnNotFound(func(name string) (val string, ok bool) { + if name == "name" { + return "tom", true + } + return + }) + str = vp.Render(tpl, nil) + assert.Eq(t, "hi, tom, 23", str) +} + +func TestRenderSMap(t *testing.T) { + tplVars := map[string]string{ + "name": "inhere", + "age": "234", + "key_01": "inhere", + "key-02": "inhere", + } + + tests := []struct { + tplText string + want string + }{ + {"hi inhere", "hi inhere"}, + {"hi {{name}}", "hi inhere"}, + {"hi {{ name}}", "hi inhere"}, + {"hi {{name }}", "hi inhere"}, + {"hi {{ name }}", "hi inhere"}, + {"hi {{ key_01 }}", "hi inhere"}, + {"hi {{ key-02 }}", "hi inhere"}, + } + + for i, tt := range tests { + t.Run(nstr.JoinAny(" ", "case", i), func(t *testing.T) { + if got := textutil.RenderSMap(tt.tplText, tplVars, ""); got != tt.want { + t.Errorf("RenderSMap() = %v, want = %v", got, tt.want) + } + }) + } + + // custom format + assert.Equal(t, "hi inhere", textutil.RenderSMap("hi {$name}", tplVars, "{$,}")) + assert.Equal(t, "hi inhere age is 234", textutil.RenderSMap("hi $name age is $age", tplVars, "$")) + assert.Equal(t, "hi inhere age is 234.", textutil.RenderSMap("hi $name age is $age.", tplVars, "$,")) + assert.Equal(t, "hi {$name}", textutil.RenderSMap("hi {$name}", nil, "{$,}")) +} + +func TestVarReplacer_ParseVars(t *testing.T) { + vp := textutil.NewVarReplacer("") + str := "hi {{ name }}, age {{age}}, age {{age }}" + ss := vp.ParseVars(str) + + assert.NotEmpty(t, ss) + assert.Len(t, ss, 2) + assert.Contains(t, ss, "name") + assert.Contains(t, ss, "age") + + tplVars := map[string]any{ + "name": "inhere", + "age": 234, + } + assert.Equal(t, "hi inhere, age 234, age 234", vp.Render(str, tplVars)) + vp.DisableFlatten() + assert.Equal(t, "hi inhere, age 234, age 234", vp.Render(str, tplVars)) +} + +func TestIsMatchAll(t *testing.T) { + str := "hi inhere, age is 120" + assert.True(t, textutil.IsMatchAll(str, []string{"hi", "inhere"})) + assert.False(t, textutil.IsMatchAll(str, []string{"hi", "^inhere"})) +} + +func TestParseInlineINI(t *testing.T) { + mp, err := textutil.ParseInlineINI("") + assert.NoErr(t, err) + assert.Empty(t, mp) + + mp, err = textutil.ParseInlineINI("default=inhere") + assert.NoErr(t, err) + assert.NotEmpty(t, mp) + assert.Eq(t, "inhere", mp.Str("default")) + + _, err = textutil.ParseInlineINI("string") + assert.ErrSubMsg(t, err, "parse inline config error: must") + + _, err = textutil.ParseInlineINI("name=n;default=inhere", "name") + assert.ErrSubMsg(t, err, "parse inline config error: invalid key name") +} diff --git a/nstr/textutil/var_replacer.go b/nstr/textutil/var_replacer.go new file mode 100644 index 0000000..e602fc6 --- /dev/null +++ b/nstr/textutil/var_replacer.go @@ -0,0 +1,212 @@ +package textutil + +import ( + "git.noahlan.cn/noahlan/ntool/internal/common" + "git.noahlan.cn/noahlan/ntool/narr" + "git.noahlan.cn/noahlan/ntool/nmap" + "git.noahlan.cn/noahlan/ntool/nstr" + "reflect" + "regexp" + "strings" +) + +const defaultVarFormat = "{{,}}" + +// FallbackFn type +type FallbackFn = func(name string) (val string, ok bool) + +// VarReplacer struct +type VarReplacer struct { + init bool + + Left, Right string + lLen, rLen int + + varReg *regexp.Regexp + // flatten sub map in vars + flatSubs bool + parseEnv bool + // support parse default value. eg: {{ name | inhere }} + parseDef bool + // keepMissVars list. default False: will clear on each replace + keepMissVars bool + // missing vars list + missVars []string + // NotFound handler + NotFound FallbackFn +} + +// NewVarReplacer instance +func NewVarReplacer(format string, opFns ...func(vp *VarReplacer)) *VarReplacer { + vp := &VarReplacer{flatSubs: true} + for _, fn := range opFns { + fn(vp) + } + return vp.WithFormat(format) +} + +// NewFullReplacer instance +func NewFullReplacer(format string) *VarReplacer { + return NewVarReplacer(format, func(vp *VarReplacer) { + vp.WithParseEnv().WithParseDefault().KeepMissingVars() + }) +} + +// DisableFlatten on the input vars map +func (r *VarReplacer) DisableFlatten() *VarReplacer { + r.flatSubs = false + return r +} + +// KeepMissingVars on the replacement handle +func (r *VarReplacer) KeepMissingVars() *VarReplacer { + r.keepMissVars = true + return r +} + +// WithParseDefault value on the input template contents +func (r *VarReplacer) WithParseDefault() *VarReplacer { + r.parseDef = true + return r +} + +// WithParseEnv on the input vars value +func (r *VarReplacer) WithParseEnv() *VarReplacer { + r.parseEnv = true + return r +} + +// OnNotFound var handle +func (r *VarReplacer) OnNotFound(fn FallbackFn) *VarReplacer { + r.NotFound = fn + return r +} + +// WithFormat custom var template +func (r *VarReplacer) WithFormat(format string) *VarReplacer { + r.Left, r.Right = nstr.QuietCut(nstr.OrElse(format, defaultVarFormat), ",") + r.Init() + return r +} + +// Init var matcher +func (r *VarReplacer) Init() *VarReplacer { + if !r.init { + r.lLen, r.rLen = len(r.Left), len(r.Right) + if r.Right != "" { + r.varReg = regexp.MustCompile(regexp.QuoteMeta(r.Left) + `([\w\s\|.-]+)` + regexp.QuoteMeta(r.Right)) + } else { + // no right tag. eg: $name, $user.age + r.varReg = regexp.MustCompile(regexp.QuoteMeta(r.Left) + `(\w[\w-]*(?:\.[\w-]+)*)`) + } + } + + return r +} + +// ParseVars the text contents and collect vars +func (r *VarReplacer) ParseVars(s string) []string { + ss := narr.StringsMap(r.varReg.FindAllString(s, -1), func(val string) string { + return strings.TrimSpace(val[r.lLen : len(val)-r.rLen]) + }) + + return narr.Unique(ss) +} + +// Render any-map vars in the text contents +func (r *VarReplacer) Render(s string, tplVars map[string]any) string { + return r.Replace(s, tplVars) +} + +// Replace any-map vars in the text contents +func (r *VarReplacer) Replace(s string, tplVars map[string]any) string { + if !strings.Contains(s, r.Left) { + return s + } + if !r.parseDef && len(tplVars) == 0 { + return s + } + + var varMap map[string]string + + if r.flatSubs { + varMap = make(map[string]string, len(tplVars)*2) + nmap.FlatWithFunc(tplVars, func(path string, val reflect.Value) { + if val.Kind() == reflect.String { + if r.parseEnv { + varMap[path] = common.ParseEnvVar(val.String(), nil) + } else { + varMap[path] = val.String() + } + } else { + varMap[path] = nstr.SafeString(val.Interface()) + } + }) + } else { + varMap = nmap.ToStringMap(tplVars) + } + + return r.Init().doReplace(s, varMap) +} + +// ReplaceSMap string-map vars in the text contents +func (r *VarReplacer) ReplaceSMap(s string, varMap map[string]string) string { + return r.RenderSimple(s, varMap) +} + +// RenderSimple string-map vars in the text contents. alias of ReplaceSMap() +func (r *VarReplacer) RenderSimple(s string, varMap map[string]string) string { + if len(varMap) == 0 || !strings.Contains(s, r.Left) { + return s + } + + if r.parseEnv { + for name, val := range varMap { + varMap[name] = common.ParseEnvVar(val, nil) + } + } + + return r.Init().doReplace(s, varMap) +} + +// MissVars list +func (r *VarReplacer) MissVars() []string { + return r.missVars +} + +// ResetMissVars list +func (r *VarReplacer) ResetMissVars() { + r.missVars = make([]string, 0) +} + +// Replace string-map vars in the text contents +func (r *VarReplacer) doReplace(s string, varMap map[string]string) string { + if !r.keepMissVars { + r.missVars = make([]string, 0) // clear on each replace + } + + return r.varReg.ReplaceAllStringFunc(s, func(sub string) string { + name := strings.TrimSpace(sub[r.lLen : len(sub)-r.rLen]) + + var defVal string + if r.parseDef && strings.ContainsRune(name, '|') { + name, defVal = nstr.TrimCut(name, "|") + } + + if val, ok := varMap[name]; ok { + return val + } + + if r.NotFound != nil { + if val, ok := r.NotFound(name); ok { + return val + } + } + + if len(defVal) > 0 { + return defVal + } + r.missVars = append(r.missVars, name) + return sub + }) +} diff --git a/nstr/util.go b/nstr/util.go new file mode 100644 index 0000000..416e5b1 --- /dev/null +++ b/nstr/util.go @@ -0,0 +1,333 @@ +package nstr + +import ( + "bytes" + "encoding/json" + "fmt" + "strings" + "text/template" + "unicode" + "unicode/utf8" +) + +// OrCond return s1 on cond is True, OR return s2. +// Like: cond ? s1 : s2 +func OrCond(cond bool, s1, s2 string) string { + if cond { + return s1 + } + return s2 +} + +// OrElse return s OR orVal(new-value) on s is empty +func OrElse(s, orVal string) string { + if s != "" { + return s + } + return orVal +} + +// OrHandle return fn(s) on s is not empty. +func OrHandle(s string, fn func(s string) string) string { + if s != "" { + return fn(s) + } + return s +} + +// Valid return first not empty element. +func Valid(ss ...string) string { + for _, s := range ss { + if s != "" { + return s + } + } + return "" +} + +// Replaces replace multi strings +// +// pairs: {old1: new1, old2: new2, ...} +// +// Can also use: +// +// strings.NewReplacer("old1", "new1", "old2", "new2").Replace(str) +func Replaces(str string, pairs map[string]string) string { + return NewReplacer(pairs).Replace(str) +} + +// NewReplacer instance +func NewReplacer(pairs map[string]string) *strings.Replacer { + ss := make([]string, len(pairs)*2) + for old, newVal := range pairs { + ss = append(ss, old, newVal) + } + return strings.NewReplacer(ss...) +} + +// PrettyJSON get pretty Json string +// Deprecated: please use fmtutil.PrettyJSON() or jsonutil.Pretty() instead it +func PrettyJSON(v any) (string, error) { + out, err := json.MarshalIndent(v, "", " ") + return string(out), err +} + +// RenderTemplate render text template +func RenderTemplate(input string, data any, fns template.FuncMap, isFile ...bool) string { + return RenderText(input, data, fns, isFile...) +} + +// RenderText render text template +func RenderText(input string, data any, fns template.FuncMap, isFile ...bool) string { + t := template.New("simple-text") + t.Funcs(template.FuncMap{ + // don't escape content + "raw": func(s string) string { + return s + }, + "trim": func(s string) string { + return strings.TrimSpace(s) + }, + // join strings + "join": func(ss []string, sep string) string { + return strings.Join(ss, sep) + }, + // lower first char + "lcFirst": func(s string) string { + return LowerFirst(s) + }, + // upper first char + "upFirst": func(s string) string { + return UpperFirst(s) + }, + }) + + // add custom template functions + if len(fns) > 0 { + t.Funcs(fns) + } + + if len(isFile) > 0 && isFile[0] { + template.Must(t.ParseFiles(input)) + } else { + template.Must(t.Parse(input)) + } + + // use buffer receive rendered content + buf := new(bytes.Buffer) + if err := t.Execute(buf, data); err != nil { + panic(err) + } + return buf.String() +} + +// WrapTag for given string. +func WrapTag(s, tag string) string { + if s == "" { + return s + } + return fmt.Sprintf("<%s>%s", tag, s, tag) +} + +// SplitSame split string into strings with the same type of unicode character +// example: user-define -> [user,define] +func SplitSame(s string, upperCase bool) []string { + var runes [][]rune + lastCharType := 0 + charType := 0 + + // split into fields based on type of unicode character + for _, r := range s { + switch true { + case RuneIsLower(r): + charType = 1 + case RuneIsUpper(r): + charType = 2 + case RuneIsDigit(r): + charType = 3 + default: + charType = 4 + } + + if charType == lastCharType { + runes[len(runes)-1] = append(runes[len(runes)-1], r) + } else { + runes = append(runes, []rune{r}) + } + lastCharType = charType + } + + for i := 0; i < len(runes)-1; i++ { + if RuneIsUpper(runes[i][0]) && RuneIsLower(runes[i+1][0]) { + runes[i+1] = append([]rune{runes[i][len(runes[i])-1]}, runes[i+1]...) + runes[i] = runes[i][:len(runes[i])-1] + } + } + + // filter all none letters and none digit + var result []string + for _, rs := range runes { + if len(rs) > 0 && (unicode.IsLetter(rs[0]) || RuneIsDigit(rs[0])) { + if upperCase { + result = append(result, string(RuneToUpperAll(rs))) + } else { + result = append(result, string(RuneToLowerAll(rs))) + } + } + } + + return result +} + +// CamelCase coverts string to camelCase string. Non letters and numbers will be ignored. +func CamelCase(s string) string { + var builder strings.Builder + + strs := SplitSame(s, false) + for i, str := range strs { + if i == 0 { + builder.WriteString(strings.ToLower(str)) + } else { + builder.WriteString(Capitalize(str)) + } + } + + return builder.String() +} + +// Capitalize converts the first character of a string to upper case and the remaining to lower case. +func Capitalize(s string) string { + result := make([]rune, len(s)) + for i, v := range s { + if i == 0 { + result[i] = unicode.ToUpper(v) + } else { + result[i] = unicode.ToLower(v) + } + } + return string(result) +} + +// KebabCase coverts string to kebab-case, non letters and numbers will be ignored. +func KebabCase(s string) string { + result := SplitSame(s, false) + return strings.Join(result, "-") +} + +// UpperKebabCase coverts string to upper KEBAB-CASE, non letters and numbers will be ignored +func UpperKebabCase(s string) string { + result := SplitSame(s, true) + return strings.Join(result, "-") +} + +// SnakeCase coverts string to snake_case, non letters and numbers will be ignored +func SnakeCase(s string) string { + result := SplitSame(s, false) + return strings.Join(result, "_") +} + +// UpperSnakeCase coverts string to upper SNAKE_CASE, non letters and numbers will be ignored +func UpperSnakeCase(s string) string { + result := SplitSame(s, true) + return strings.Join(result, "_") +} + +// UpperFirst converts the first character of string to upper case. +func UpperFirst(s string) string { + if len(s) == 0 { + return "" + } + + r, size := utf8.DecodeRuneInString(s) + r = unicode.ToUpper(r) + + return string(r) + s[size:] +} + +// LowerFirst converts the first character of string to lower case. +func LowerFirst(s string) string { + if len(s) == 0 { + return "" + } + + r, size := utf8.DecodeRuneInString(s) + r = unicode.ToLower(r) + + return string(r) + s[size:] +} + +// Before returns the substring of the source string up to the first occurrence of the specified string. +func Before(s, char string) string { + if s == "" || char == "" { + return s + } + i := strings.Index(s, char) + return s[0:i] +} + +// BeforeLast returns the substring of the source string up to the last occurrence of the specified string. +func BeforeLast(s, char string) string { + if s == "" || char == "" { + return s + } + i := strings.LastIndex(s, char) + return s[0:i] +} + +// After returns the substring after the first occurrence of a specified string in the source string. +func After(s, char string) string { + if s == "" || char == "" { + return s + } + i := strings.Index(s, char) + return s[i+len(char):] +} + +// AfterLast returns the substring after the last occurrence of a specified string in the source string. +func AfterLast(s, char string) string { + if s == "" || char == "" { + return s + } + i := strings.LastIndex(s, char) + return s[i+len(char):] +} + +// Reverse returns string whose char order is reversed to the given string. +func Reverse(s string) string { + r := []rune(s) + for i, j := 0, len(r)-1; i < j; i, j = i+1, j-1 { + r[i], r[j] = r[j], r[i] + } + return string(r) +} + +// Wrap a string with given string. +func Wrap(str string, wrapWith string) string { + if str == "" || wrapWith == "" { + return str + } + var sb strings.Builder + sb.WriteString(wrapWith) + sb.WriteString(str) + sb.WriteString(wrapWith) + + return sb.String() +} + +// Unwrap a given string from another string. will change source string. +func Unwrap(str string, wrapToken string) string { + if str == "" || wrapToken == "" { + return str + } + + firstIndex := strings.Index(str, wrapToken) + lastIndex := strings.LastIndex(str, wrapToken) + + if firstIndex == 0 && lastIndex > 0 && lastIndex <= len(str)-1 { + if len(wrapToken) <= lastIndex { + str = str[len(wrapToken):lastIndex] + } + } + + return str +} diff --git a/nstruct/check.go b/nstruct/check.go new file mode 100644 index 0000000..27f0298 --- /dev/null +++ b/nstruct/check.go @@ -0,0 +1,14 @@ +package nstruct + +// IsExported field name on struct +func IsExported(name string) bool { + return name[0] >= 'A' && name[0] <= 'Z' +} + +// IsUnexported field name on struct +func IsUnexported(name string) bool { + if name[0] == '_' { + return true + } + return name[0] >= 'a' && name[0] <= 'z' +} diff --git a/nstruct/init.go b/nstruct/init.go new file mode 100644 index 0000000..93345e3 --- /dev/null +++ b/nstruct/init.go @@ -0,0 +1,178 @@ +package nstruct + +import ( + "errors" + "git.noahlan.cn/noahlan/ntool/nenv" + "git.noahlan.cn/noahlan/ntool/nreflect" + "git.noahlan.cn/noahlan/ntool/nstr" + "reflect" +) + +const defaultInitTag = "default" + +// InitOptFunc define +type InitOptFunc func(opt *InitOptions) + +// InitOptions struct +type InitOptions struct { + // TagName default value tag name. tag: default + TagName string + // ParseEnv var name on default value. eg: `default:"${APP_ENV}"` + // + // default: false + ParseEnv bool + // ValueHook before set value hook TODO + ValueHook func(val string) any +} + +// Init struct default value by field "default" tag. +func Init(ptr any, optFns ...InitOptFunc) error { + return InitDefaults(ptr, optFns...) +} + +// InitDefaults init struct default value by field "default" tag. +// +// TIPS: +// +// Support init field types: string, bool, intX, uintX, floatX, array, slice +// +// Example: +// +// type User1 struct { +// Name string `default:"inhere"` +// Age int32 `default:"30"` +// } +// +// u1 := &User1{} +// err = structs.InitDefaults(u1) +// fmt.Printf("%+v\n", u1) // Output: {Name:inhere Age:30} +func InitDefaults(ptr any, optFns ...InitOptFunc) error { + rv := reflect.ValueOf(ptr) + if rv.Kind() != reflect.Ptr { + return errors.New("must be provider an pointer value") + } + + rv = rv.Elem() + if rv.Kind() != reflect.Struct { + return errors.New("must be provider an struct value") + } + + opt := &InitOptions{TagName: defaultInitTag} + for _, fn := range optFns { + fn(opt) + } + + return initDefaults(rv, opt) +} + +func initDefaults(rv reflect.Value, opt *InitOptions) error { + rt := rv.Type() + + for i := 0; i < rt.NumField(); i++ { + sf := rt.Field(i) + // skip don't exported field + if IsUnexported(sf.Name) { + continue + } + + val, hasTag := sf.Tag.Lookup(opt.TagName) + if !hasTag || val == "-" { + continue + } + + fv := rv.Field(i) + if fv.Kind() == reflect.Struct { + if err := initDefaults(fv, opt); err != nil { + return err + } + continue + } + + // skip on field has value + if !fv.IsZero() { + // special: handle for pointer struct field + if fv.Kind() == reflect.Pointer { + fv = fv.Elem() + if fv.Kind() == reflect.Struct { + if err := initDefaults(fv, opt); err != nil { + return err + } + } + } + continue + } + + // handle for pointer field + if fv.Kind() == reflect.Pointer { + if fv.IsNil() { + fv.Set(reflect.New(fv.Type().Elem())) + } + + fv = fv.Elem() + if fv.Kind() == reflect.Struct { + if err := initDefaults(fv, opt); err != nil { + return err + } + continue + } + } else if fv.Kind() == reflect.Slice { + el := sf.Type.Elem() + isPtr := el.Kind() == reflect.Pointer + if isPtr { + el = el.Elem() + } + + // init sub struct in slice. like `[]SubStruct` or `[]*SubStruct` + if el.Kind() == reflect.Struct { + // make sub-struct and init. like: `SubStruct` + subFv := reflect.New(el) + subFvE := subFv.Elem() + if err := initDefaults(subFvE, opt); err != nil { + return err + } + + // make new slice and set value. + newFv := reflect.MakeSlice(reflect.SliceOf(sf.Type.Elem()), 0, 1) + if isPtr { + newFv = reflect.Append(newFv, subFv) + } else { + newFv = reflect.Append(newFv, subFvE) + } + fv.Set(newFv) + continue + } + } + + if err := initDefaultValue(fv, val, opt.ParseEnv); err != nil { + return err + } + } + + return nil +} + +func initDefaultValue(fv reflect.Value, val string, parseEnv bool) error { + if val == "" || !fv.CanSet() { + return nil + } + + // parse env var + if parseEnv { + val = nenv.ParseEnvVar(val, nil) + } + + var anyVal any = val + + // simple slice: convert simple kind(string,intX,uintX,...) to slice. eg: "1,2,3" => []int{1,2,3} + if nreflect.IsArrayOrSlice(fv.Kind()) && nreflect.IsSimpleKind(nreflect.SliceElemKind(fv.Type())) { + ss := nstr.SplitTrimmed(val, ",") + valRv, err := nreflect.ConvSlice(reflect.ValueOf(ss), fv.Type().Elem()) + if err == nil { + nreflect.SetRValue(fv, valRv) + } + return err + } + + // set value + return nreflect.SetValue(fv, anyVal) +} diff --git a/nstruct/mapconv.go b/nstruct/mapconv.go new file mode 100644 index 0000000..ac39249 --- /dev/null +++ b/nstruct/mapconv.go @@ -0,0 +1,173 @@ +package nstruct + +import ( + "errors" + "fmt" + "git.noahlan.cn/noahlan/ntool/nmap" + "git.noahlan.cn/noahlan/ntool/nreflect" + "reflect" +) + +// ToMap quickly convert structs to map by reflect +func ToMap(st any, optFns ...MapOptFunc) map[string]any { + mp, _ := StructToMap(st, optFns...) + return mp +} + +// MustToMap alis of TryToMap, but will panic on error +func MustToMap(st any, optFns ...MapOptFunc) map[string]any { + mp, err := StructToMap(st, optFns...) + if err != nil { + panic(err) + } + return mp +} + +// TryToMap simple convert structs to map by reflect +func TryToMap(st any, optFns ...MapOptFunc) (map[string]any, error) { + return StructToMap(st, optFns...) +} + +// ToSMap quickly and safe convert structs to map[string]string by reflect +func ToSMap(st any, optFns ...MapOptFunc) map[string]string { + mp, _ := StructToMap(st, optFns...) + return nmap.ToStringMap(mp) +} + +// TryToSMap quickly convert structs to map[string]string by reflect +func TryToSMap(st any, optFns ...MapOptFunc) (map[string]string, error) { + mp, err := StructToMap(st, optFns...) + if err != nil { + return nil, err + } + return nmap.ToStringMap(mp), nil +} + +// MustToSMap alias of ToStringMap(), but will panic on error +func MustToSMap(st any, optFns ...MapOptFunc) map[string]string { + mp, err := StructToMap(st, optFns...) + if err != nil { + panic(err) + } + return nmap.ToStringMap(mp) +} + +// ToString quickly format struct to string +func ToString(st any, optFns ...MapOptFunc) string { + mp, err := StructToMap(st, optFns...) + if err == nil { + return nmap.ToString(mp) + } + return fmt.Sprint(st) +} + +const defaultFieldTag = "json" + +// MapOptions for convert struct to map +type MapOptions struct { + // TagName for map filed. default is "json" + TagName string + // ParseDepth for parse. TODO support depth + ParseDepth int + // MergeAnonymous struct fields to parent map. default is true + MergeAnonymous bool + // ExportPrivate export private fields. default is false + ExportPrivate bool +} + +// MapOptFunc define +type MapOptFunc func(opt *MapOptions) + +// WithMapTagName set tag name for map field +func WithMapTagName(tagName string) MapOptFunc { + return func(opt *MapOptions) { + opt.TagName = tagName + } +} + +// MergeAnonymous merge anonymous struct fields to parent map +func MergeAnonymous(opt *MapOptions) { + opt.MergeAnonymous = true +} + +// ExportPrivate merge anonymous struct fields to parent map +func ExportPrivate(opt *MapOptions) { + opt.ExportPrivate = true +} + +// StructToMap quickly convert structs to map[string]any by reflect. +// Can custom export field name by tag `json` or custom tag +func StructToMap(st any, optFns ...MapOptFunc) (map[string]any, error) { + mp := make(map[string]any) + if st == nil { + return mp, nil + } + + obj := reflect.Indirect(reflect.ValueOf(st)) + if obj.Kind() != reflect.Struct { + return mp, errors.New("must be an struct value") + } + + opt := &MapOptions{TagName: defaultFieldTag} + for _, fn := range optFns { + fn(opt) + } + + _, err := structToMap(obj, opt, mp) + return mp, err +} + +func structToMap(obj reflect.Value, opt *MapOptions, mp map[string]any) (map[string]any, error) { + if mp == nil { + mp = make(map[string]any) + } + + refType := obj.Type() + for i := 0; i < obj.NumField(); i++ { + ft := refType.Field(i) + name := ft.Name + // skip un-exported field + if !opt.ExportPrivate && IsUnexported(name) { + continue + } + + tagVal, ok := ft.Tag.Lookup(opt.TagName) + if ok && tagVal != "" { + sMap, err := ParseTagValueDefault(name, tagVal) + if err != nil { + return nil, err + } + + name = sMap.Default("name", name) + if name == "" { // un-exported field + continue + } + } + + field := reflect.Indirect(obj.Field(i)) + if field.Kind() == reflect.Struct { + // collect anonymous struct values to parent. + if ft.Anonymous && opt.MergeAnonymous { + _, err := structToMap(field, opt, mp) + if err != nil { + return nil, err + } + } else { // collect struct values to submap + sub, err := structToMap(field, opt, nil) + if err != nil { + return nil, err + } + mp[name] = sub + } + continue + } + + if field.CanInterface() { + mp[name] = field.Interface() + } else if field.CanAddr() { // for unexported field + mp[name] = nreflect.UnexportedValue(field) + } + } + + return mp, nil +} diff --git a/nstruct/tags.go b/nstruct/tags.go new file mode 100644 index 0000000..acf57e0 --- /dev/null +++ b/nstruct/tags.go @@ -0,0 +1,273 @@ +package nstruct + +import ( + "errors" + "fmt" + "git.noahlan.cn/noahlan/ntool/narr" + "git.noahlan.cn/noahlan/ntool/nmap" + "git.noahlan.cn/noahlan/ntool/nstr" + "reflect" + "strings" +) + +// ErrNotAnStruct error +// var emptyStringMap = make(nmap.SMap) +var ErrNotAnStruct = errors.New("must input an struct value") + +// ParseTags for parse struct tags. +func ParseTags(st any, tagNames []string) (map[string]nmap.SMap, error) { + p := NewTagParser(tagNames...) + + if err := p.Parse(st); err != nil { + return nil, err + } + return p.Tags(), nil +} + +// ParseReflectTags parse struct tags info. +func ParseReflectTags(rt reflect.Type, tagNames []string) (map[string]nmap.SMap, error) { + p := NewTagParser(tagNames...) + + if err := p.ParseType(rt); err != nil { + return nil, err + } + return p.Tags(), nil +} + +// TagValFunc handle func +type TagValFunc func(field, tagVal string) (nmap.SMap, error) + +// TagParser struct +type TagParser struct { + // TagNames want parsed tag names. + TagNames []string + // ValueFunc tag value parse func. + ValueFunc TagValFunc + + // key: field name + // value: tag map {tag-name: value string.} + tags map[string]nmap.SMap +} + +// Tags map data for struct fields +func (p *TagParser) Tags() map[string]nmap.SMap { + return p.tags +} + +// NewTagParser instance +func NewTagParser(tagNames ...string) *TagParser { + return &TagParser{ + TagNames: tagNames, + ValueFunc: ParseTagValueDefault, + } +} + +// Parse an struct value +func (p *TagParser) Parse(st any) error { + rv := reflect.ValueOf(st) + if rv.Kind() == reflect.Ptr && !rv.IsNil() { + rv = rv.Elem() + } + + return p.ParseType(rv.Type()) +} + +// ParseType parse a struct type value +func (p *TagParser) ParseType(rt reflect.Type) error { + if rt.Kind() != reflect.Struct { + return ErrNotAnStruct + } + + // key is field name. + p.tags = make(map[string]nmap.SMap) + return p.parseType(rt, "") +} + +func (p *TagParser) parseType(rt reflect.Type, parent string) error { + for i := 0; i < rt.NumField(); i++ { + sf := rt.Field(i) + + // skip don't exported field + name := sf.Name + if name[0] >= 'a' && name[0] <= 'z' { + continue + } + + smp := make(nmap.SMap) + for _, tagName := range p.TagNames { + // eg: `json:"age"` + // eg: "name=int0;shorts=i;required=true;desc=int option message" + tagVal := sf.Tag.Get(tagName) + if tagVal == "" { + continue + } + + smp[tagName] = tagVal + } + + pathKey := name + if parent != "" { + pathKey = parent + "." + name + } + + p.tags[pathKey] = smp + + ft := sf.Type + if ft.Kind() == reflect.Ptr { + ft = ft.Elem() + } + + // field is struct. + if ft.Kind() == reflect.Struct { + err := p.parseType(ft, pathKey) + if err != nil { + return err + } + } + } + return nil +} + +// Info parse the give field, returns tag value info. +// +// info, err := p.Info("Name", "json") +// exportField := info.Get("name") +func (p *TagParser) Info(field, tag string) (nmap.SMap, error) { + field = nstr.UpperFirst(field) + fTags, ok := p.tags[field] + if !ok { + return nil, fmt.Errorf("field %q not found", field) + } + + val, ok := fTags.Value(tag) + if !ok { + return make(nmap.SMap), nil + } + + // parse tag value + return p.ValueFunc(field, val) +} + +/************************************************************* + * some built in tag value parse func + *************************************************************/ + +// ParseTagValueDefault parse like json tag value. +// +// see json.Marshal(): +// +// // JSON as key "myName", skipped if empty. +// Field int `json:"myName,omitempty"` +// +// // Field appears in JSON as key "Field" (the default), but skipped if empty. +// Field int `json:",omitempty"` +// +// // Field is ignored by this package. +// Field int `json:"-"` +// +// // Field appears in JSON as key "-". +// Field int `json:"-,"` +// +// Int64String int64 `json:",string"` +// +// Returns: +// +// { +// "name": "myName", // maybe is empty, on tag value is "-" +// "omitempty": "true", +// "string": "true", +// // ... more custom bool settings. +// } +func ParseTagValueDefault(field, tagVal string) (mp nmap.SMap, err error) { + ss := nstr.SplitTrimmed(tagVal, ",") + ln := len(ss) + if ln == 0 || tagVal == "," { + return nmap.SMap{"name": field}, nil + } + + mp = make(nmap.SMap, ln) + if ln == 1 { + // valid field name + if ss[0] != "-" { + mp["name"] = ss[0] + } + return + } + + // ln > 1 + mp["name"] = ss[0] + // other settings: omitempty, string + for _, key := range ss[1:] { + mp[key] = "true" + } + return +} + +// ParseTagValueQuick quick parse tag value string by sep(;) +func ParseTagValueQuick(tagVal string, defines []string) nmap.SMap { + parseFn := ParseTagValueDefine(";", defines) + + mp, _ := parseFn("", tagVal) + return mp +} + +// ParseTagValueDefine parse tag value string by given defines. +// +// Examples: +// +// eg: "desc;required;default;shorts" +// type MyStruct { +// Age int `flag:"int option message;;a,b"` +// } +// sepStr := ";" +// defines := []string{"desc", "required", "default", "shorts"} +func ParseTagValueDefine(sep string, defines []string) TagValFunc { + defNum := len(defines) + + return func(field, tagVal string) (nmap.SMap, error) { + ss := nstr.SplitNTrimmed(tagVal, sep, defNum) + ln := len(ss) + mp := make(nmap.SMap, ln) + if ln == 0 { + return mp, nil + } + + for i, val := range ss { + key := defines[i] + mp[key] = val + } + return mp, nil + } +} + +// ParseTagValueNamed parse k-v tag value string. it's like INI format contents. +// +// Examples: +// +// eg: "name=val0;shorts=i;required=true;desc=a message" +// => +// {name: val0, shorts: i, required: true, desc: a message} +func ParseTagValueNamed(field, tagVal string, keys ...string) (mp nmap.SMap, err error) { + ss := nstr.Split(tagVal, ";") + ln := len(ss) + if ln == 0 { + return + } + + mp = make(nmap.SMap, ln) + for _, s := range ss { + if !strings.ContainsRune(s, '=') { + err = fmt.Errorf("parse tag error on field '%s': must match `KEY=VAL`", field) + return + } + + key, val := nstr.TrimCut(s, "=") + if len(keys) > 0 && !narr.StringsHas(keys, key) { + err = fmt.Errorf("parse tag error on field '%s': invalid key name '%s'", field, key) + return + } + + mp[key] = val + } + return +} diff --git a/nsys/atomic/atomic_duration.go b/nsys/atomic/atomic_duration.go new file mode 100644 index 0000000..16c1214 --- /dev/null +++ b/nsys/atomic/atomic_duration.go @@ -0,0 +1,36 @@ +package atomic + +import ( + "sync/atomic" + "time" +) + +// An AtomicDuration is an implementation of atomic duration. +type AtomicDuration int64 + +// NewAtomicDuration returns an AtomicDuration. +func NewAtomicDuration() *AtomicDuration { + return new(AtomicDuration) +} + +// ForAtomicDuration returns an AtomicDuration with given value. +func ForAtomicDuration(val time.Duration) *AtomicDuration { + d := NewAtomicDuration() + d.Set(val) + return d +} + +// CompareAndSwap compares current value with old, if equals, set the value to val. +func (d *AtomicDuration) CompareAndSwap(old, val time.Duration) bool { + return atomic.CompareAndSwapInt64((*int64)(d), int64(old), int64(val)) +} + +// Load loads the current duration. +func (d *AtomicDuration) Load() time.Duration { + return time.Duration(atomic.LoadInt64((*int64)(d))) +} + +// Set sets the value to val. +func (d *AtomicDuration) Set(val time.Duration) { + atomic.StoreInt64((*int64)(d), int64(val)) +} diff --git a/nsys/atomic/atomic_int64.go b/nsys/atomic/atomic_int64.go new file mode 100644 index 0000000..7bcf27d --- /dev/null +++ b/nsys/atomic/atomic_int64.go @@ -0,0 +1,60 @@ +package atomic + +import "sync/atomic" + +type AtomicInt64 struct { + val *atomic.Int64 +} + +func NewAtomicInt64() *AtomicInt64 { + return &AtomicInt64{val: &atomic.Int64{}} +} + +func (a *AtomicInt64) Reset() { + a.val.Store(0) +} + +// AddAndGet 以原子方式将当前值+delta,并返回+delta之后的值 +func (a *AtomicInt64) AddAndGet(delta int64) int64 { + return a.val.Add(delta) +} + +// GetAndAdd 以原子方式将当前值+delta,并返回+delta之前的值 +func (a *AtomicInt64) GetAndAdd(delta int64) int64 { + tmp := a.val.Load() + a.val.Add(delta) + return tmp +} + +// UpdateAndGet 以原子方式更新值,并返回更新后的值 +func (a *AtomicInt64) UpdateAndGet(val int64) int64 { + a.val.Store(val) + return val +} + +// GetAndUpdate 以原子方式更新值,并返回更新前的值 +func (a *AtomicInt64) GetAndUpdate(val int64) int64 { + tmp := a.val.Load() + a.val.Store(val) + return tmp +} + +// GetAndIncrement 以原子方式将当前值+1,并返回+1前的值 +func (a *AtomicInt64) GetAndIncrement() int64 { + return a.GetAndAdd(1) +} + +// GetAndDecrement 以原子方式将当前值-1,并返回-1前的值 +func (a *AtomicInt64) GetAndDecrement() int64 { + return a.GetAndAdd(-1) +} + +// IncrementAndGet 以原子方式将当前值+1,并返回+1后的值 +func (a *AtomicInt64) IncrementAndGet() int64 { + return a.AddAndGet(1) +} + +// DecrementAndGet 以原子方式将当前值-1,并返回-1后的值 +func (a *AtomicInt64) DecrementAndGet() int64 { + return a.AddAndGet(-1) +} diff --git a/nsys/clipboard/clipboard.go b/nsys/clipboard/clipboard.go new file mode 100644 index 0000000..26f09a7 --- /dev/null +++ b/nsys/clipboard/clipboard.go @@ -0,0 +1,210 @@ +package clipboard + +import ( + "bytes" + "errors" + "fmt" + "git.noahlan.cn/noahlan/ntool/ncli" + "git.noahlan.cn/noahlan/ntool/nfs" + "git.noahlan.cn/noahlan/ntool/nsys" + "github.com/gookit/color" + "io" + "os/exec" + "strings" +) + +// Clipboard struct +type Clipboard struct { + // TODO add event on write, read + // buffer for write + buf *bytes.Buffer + + // print exec command line on run + verbose bool + // available - bin file exist on the OS. + writeable, readable bool + + readerBin string + readArgs []string + writerBin string + writeArgs []string +} + +// New instance +func New() *Clipboard { + // special handle on with args + reader, readArgs := parseLine(GetReaderBin()) + writer, writeArgs := parseLine(GetWriterBin()) + + return &Clipboard{ + readerBin: reader, + readArgs: readArgs, + writerBin: writer, + writeArgs: writeArgs, + readable: nsys.HasExecutable(reader), + writeable: nsys.HasExecutable(writer), + } +} + +// WithVerbose setting +func (c *Clipboard) WithVerbose(yn bool) *Clipboard { + c.verbose = yn + return c +} + +// Clean the clipboard +func (c *Clipboard) Clean() error { return c.Reset() } + +// Reset and clean the clipboard +func (c *Clipboard) Reset() error { + if c.buf != nil { + c.buf.Reset() + } + + // echo empty string for clean clipboard. + // run: echo '' | pbcopy + return c.WriteFrom(strings.NewReader("")) +} + +// +// ---------------------------------------- write ---------------------------------------- +// + +// Write bytes data to clipboard +func (c *Clipboard) Write(p []byte) (int, error) { + return c.WriteString(string(p)) +} + +// WriteString data to clipboard +func (c *Clipboard) WriteString(s string) (int, error) { + // if c.addSlashes { + // s = strutil.AddSlashes(s) + // } + return c.buffer().WriteString(s) +} + +// Flush buffer contents to clipboard +func (c *Clipboard) Flush() error { + if c.buf == nil || c.buf.Len() == 0 { + return errors.New("clipboard: empty contents for write") + } + + defer c.buf.Reset() + return c.WriteFrom(c.buf) +} + +// WriteFromFile contents to clipboard +func (c *Clipboard) WriteFromFile(filepath string) error { + // eg: + // Mac: pbcopy < tempfile.txt + file, err := nfs.OpenReadFile(filepath) + if err != nil { + return err + } + + defer file.Close() + return c.WriteFrom(file) +} + +// WriteFrom reader data to clipboard +func (c *Clipboard) WriteFrom(r io.Reader) error { + if !c.writeable { + return fmt.Errorf("clipboard: write driver %q not found on OS", c.writerBin) + } + + cmd := exec.Command(c.writerBin, c.writeArgs...) + cmd.Stdin = r + + if c.verbose { + color.Yellow.Printf("clipboard> %s\n", ncli.BuildLine(c.writerBin, c.writeArgs)) + } + return cmd.Run() +} + +// +// ---------------------------------------- read ---------------------------------------- +// + +// Read contents from clipboard +func (c *Clipboard) Read() ([]byte, error) { + buf, err := c.ReadToBuffer() + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// ReadToBuffer contents from clipboard +func (c *Clipboard) ReadToBuffer() (*bytes.Buffer, error) { + var buf bytes.Buffer + if err := c.ReadTo(&buf); err != nil { + return nil, err + } + return &buf, nil +} + +// ReadString contents as string from clipboard +func (c *Clipboard) ReadString() (string, error) { + bts, err := c.Read() + if err != nil { + return "", err + } + + // fix: at Windows will always return end of the "\r\n" + if nsys.IsWindows() { + return string(bytes.TrimRight(bts, "\r\n")), nil + } + return string(bts), nil +} + +// ReadToFile dump clipboard data to file +func (c *Clipboard) ReadToFile(filepath string) error { + file, err := nfs.QuickOpenFile(filepath) + if err != nil { + return err + } + + defer file.Close() + return c.ReadTo(file) +} + +// ReadTo read clipboard contents to writer +func (c *Clipboard) ReadTo(w io.Writer) error { + if !c.readable { + return fmt.Errorf("clipboard: read driver %q not found on OS", c.readerBin) + } + + cmd := exec.Command(c.readerBin, c.readArgs...) + cmd.Stdout = w + + if c.verbose { + color.Yellow.Printf("clipboard> %s\n", ncli.BuildLine(c.writerBin, c.writeArgs)) + } + return cmd.Run() +} + +// +// ---------------------------------------- help ---------------------------------------- +// + +// Available check +func (c *Clipboard) Available() bool { + return c.writeable && c.readable && available() +} + +// Writeable check +func (c *Clipboard) Writeable() bool { + return c.writeable +} + +// Readable check +func (c *Clipboard) Readable() bool { + return c.readable +} + +func (c *Clipboard) buffer() *bytes.Buffer { + if c.buf == nil { + c.buf = new(bytes.Buffer) + } + return c.buf +} diff --git a/nsys/clipboard/clipboard_test.go b/nsys/clipboard/clipboard_test.go new file mode 100644 index 0000000..14b0e47 --- /dev/null +++ b/nsys/clipboard/clipboard_test.go @@ -0,0 +1,40 @@ +package clipboard_test + +import ( + "git.noahlan.cn/noahlan/ntool/nfs" + "git.noahlan.cn/noahlan/ntool/nstr" + "git.noahlan.cn/noahlan/ntool/nsys/clipboard" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestClipboard_WriteFromFile(t *testing.T) { + cb := clipboard.New() + if ok := cb.Available(); !ok { + assert.False(t, ok) + t.Skipf("skip test on program '%s' not found", clipboard.GetReaderBin()) + return + } + + srcFile := "testdata/testcb.txt" + srcStr := string(nfs.MustReadFile(srcFile)) + assert.NotEmpty(t, srcStr) + + err := cb.WriteFromFile(srcFile) + assert.NoErr(t, err) + + err = cb.WriteFromFile("path/to/not-exists.txt") + assert.Err(t, err) + + readStr, err := cb.ReadString() + assert.NoErr(t, err) + assert.Eq(t, srcStr, nstr.Trim(readStr)) + + dstFile := "testdata/read-from-cb.txt" + assert.NoErr(t, nfs.RmFileIfExist(dstFile)) + err = cb.ReadToFile(dstFile) + assert.NoErr(t, err) + + dstStr := string(nfs.MustReadFile(dstFile)) + assert.Eq(t, srcStr, nstr.Trim(dstStr)) +} diff --git a/nsys/clipboard/testdata/read-from-cb.txt b/nsys/clipboard/testdata/read-from-cb.txt new file mode 100644 index 0000000..67b5752 --- /dev/null +++ b/nsys/clipboard/testdata/read-from-cb.txt @@ -0,0 +1 @@ +hi, contents from file diff --git a/nsys/clipboard/testdata/testcb.txt b/nsys/clipboard/testdata/testcb.txt new file mode 100644 index 0000000..8468558 --- /dev/null +++ b/nsys/clipboard/testdata/testcb.txt @@ -0,0 +1 @@ +hi, contents from file \ No newline at end of file diff --git a/nsys/clipboard/util.go b/nsys/clipboard/util.go new file mode 100644 index 0000000..e2fc34b --- /dev/null +++ b/nsys/clipboard/util.go @@ -0,0 +1,87 @@ +package clipboard + +import "strings" + +// clipboard writer, reader program names +const ( + // WriterOnMac driver + // + // Example: + // echo hello | pbcopy + // pbcopy < tempfile.txt + WriterOnMac = "pbcopy" + + // WriterOnWin driver on Windows + // + // TIP: clip only support write contents to clipboard. + WriterOnWin = "clip" + + // WriterOnLin driver name + // + // linux: + // echo "hello-c" | xclip -selection c + WriterOnLin = "xclip -selection clipboard" + + // ReaderOnMac driver + // + // Example: + // Mac: pbpaste >> tasklist.txt + ReaderOnMac = "pbpaste" + + // ReaderOnWin driver on Windows + // + // read clipboard should use: powershell get-clipboard + ReaderOnWin = "powershell get-clipboard" + + // ReaderOnLin driver name + // + // Usage: + // xclip -o -selection clipboard + // xclip -o -selection c // can use shorts + ReaderOnLin = "xclip -o -selection clipboard" +) + +var ( + writerOnLin = []string{"xclip", "xsel"} +) + +// std instance +var std = New() + +// Std get +func Std() *Clipboard { + return std +} + +// Reset clipboard data +func Reset() error { + return std.Reset() +} + +// Available clipboard available check +func Available() bool { + return std.Available() +} + +// ReadString contents from clipboard +func ReadString() (string, error) { + return std.ReadString() +} + +// WriteString contents to clipboard and flush +func WriteString(s string) error { + if _, err := std.WriteString(s); err != nil { + return err + } + return std.Flush() +} + +// special handle on with args +func parseLine(line string) (bin string, args []string) { + bin = line + if strings.ContainsRune(line, ' ') { + list := strings.Split(line, " ") + bin, args = list[0], list[1:] + } + return +} diff --git a/nsys/clipboard/util_darwin.go b/nsys/clipboard/util_darwin.go new file mode 100644 index 0000000..db8ee54 --- /dev/null +++ b/nsys/clipboard/util_darwin.go @@ -0,0 +1,15 @@ +//go:build darwin + +package clipboard + +// GetWriterBin program name +func GetWriterBin() string { + return WriterOnMac +} + +// GetReaderBin program name +func GetReaderBin() string { + return ReaderOnMac +} + +func available() bool { return true } diff --git a/nsys/clipboard/util_unix.go b/nsys/clipboard/util_unix.go new file mode 100644 index 0000000..2c2f8ea --- /dev/null +++ b/nsys/clipboard/util_unix.go @@ -0,0 +1,20 @@ +//go:build !windows && !darwin + +package clipboard + +import "os" + +// GetWriterBin program name +func GetWriterBin() string { + return WriterOnLin +} + +// GetReaderBin program name +func GetReaderBin() string { + return ReaderOnLin +} + +func available() bool { + // X clipboard is unavailable when not under X. + return os.Getenv("DISPLAY") != "" +} diff --git a/nsys/clipboard/util_windows.go b/nsys/clipboard/util_windows.go new file mode 100644 index 0000000..2a1236b --- /dev/null +++ b/nsys/clipboard/util_windows.go @@ -0,0 +1,15 @@ +//go:build windows + +package clipboard + +// GetWriterBin program name +func GetWriterBin() string { + return WriterOnWin +} + +// GetReaderBin program name +func GetReaderBin() string { + return ReaderOnWin +} + +func available() bool { return true } diff --git a/nsys/cmdn/cmd_darwin.go b/nsys/cmdn/cmd_darwin.go new file mode 100644 index 0000000..d09f7d5 --- /dev/null +++ b/nsys/cmdn/cmd_darwin.go @@ -0,0 +1,21 @@ +package cmdn + +import ( + "os/exec" + "syscall" +) + +// terminateProcess stops the command by sending its process group a SIGTERM signal. +// Stop is idempotent. An error should only be returning to the rare case that +// Stop is called immediately after the command ends but before Start can +// update its internal state. +func terminateProcess(pid int) error { + // Signal the process group (-pid), not just the process, so that the process + // and all its children are signaled. Else, child processes can keep running and + // keep the stdout/stderr fd open and cause cmd.Wait to hang. + return syscall.Kill(-pid, syscall.SIGTERM) +} + +func setProcessGroupID(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{Setgpid: true} +} diff --git a/nsys/cmdn/cmd_freebsd.go b/nsys/cmdn/cmd_freebsd.go new file mode 100644 index 0000000..d09f7d5 --- /dev/null +++ b/nsys/cmdn/cmd_freebsd.go @@ -0,0 +1,21 @@ +package cmdn + +import ( + "os/exec" + "syscall" +) + +// terminateProcess stops the command by sending its process group a SIGTERM signal. +// Stop is idempotent. An error should only be returning to the rare case that +// Stop is called immediately after the command ends but before Start can +// update its internal state. +func terminateProcess(pid int) error { + // Signal the process group (-pid), not just the process, so that the process + // and all its children are signaled. Else, child processes can keep running and + // keep the stdout/stderr fd open and cause cmd.Wait to hang. + return syscall.Kill(-pid, syscall.SIGTERM) +} + +func setProcessGroupID(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{Setgpid: true} +} diff --git a/nsys/cmdn/cmd_linux.go b/nsys/cmdn/cmd_linux.go new file mode 100644 index 0000000..37b9088 --- /dev/null +++ b/nsys/cmdn/cmd_linux.go @@ -0,0 +1,21 @@ +package cmdn + +import ( + "os/exec" + "syscall" +) + +// terminateProcess stops the command by sending its process group a SIGTERM signal. +// Stop is idempotent. An error should only be returning to the rare case that +// Stop is called immediately after the command ends but before Start can +// update its internal state. +func terminateProcess(pid int) error { + // Signal the process group (-pid), not just the process, so that the process + // and all its children are signaled. Else, child processes can keep running and + // keep the stdout/stderr fd open and cause cmd.Wait to hang. + return syscall.Kill(-pid, syscall.SIGTERM) +} + +func setProcessGroupID(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} +} diff --git a/nsys/cmdn/cmd_windows.go b/nsys/cmdn/cmd_windows.go new file mode 100644 index 0000000..169c446 --- /dev/null +++ b/nsys/cmdn/cmd_windows.go @@ -0,0 +1,23 @@ +package cmdn + +import ( + "os" + "os/exec" + "syscall" +) + +// terminateProcess stops the command by sending its process group a SIGTERM signal. +// Stop is idempotent. An error should only be returning to the rare case that +// Stop is called immediately after the command ends but before Start can +// update its internal state. +func terminateProcess(pid int) error { + p, err := os.FindProcess(pid) + if err != nil { + return err + } + return p.Kill() +} + +func setProcessGroupID(cmd *exec.Cmd) { + cmd.SysProcAttr = &syscall.SysProcAttr{} +} diff --git a/nsys/cmdn/command.go b/nsys/cmdn/command.go new file mode 100644 index 0000000..d0e4900 --- /dev/null +++ b/nsys/cmdn/command.go @@ -0,0 +1,18 @@ +package cmdn + +type ICommand interface { + // MessageID 消息ID + MessageID() string +} + +// PlainCommand 基于明文传输数据的命令 +type PlainCommand struct { + MID string // 消息ID + ID string + Cmd string + Args []string +} + +func (c *PlainCommand) MessageID() string { + return c.MID +} diff --git a/nsys/cmdn/options.go b/nsys/cmdn/options.go new file mode 100644 index 0000000..b3b1603 --- /dev/null +++ b/nsys/cmdn/options.go @@ -0,0 +1,49 @@ +package cmdn + +import ( + "git.noahlan.cn/noahlan/ntool/ndef" + "time" +) + +func WithSerializer(serializer ndef.Serializer) Option { + return func(opt *Options) { + opt.Marshaler = &ndef.MarshalerWrapper{Marshaler: serializer} + opt.Unmarshaler = &ndef.UnmarshalerWrapper{Unmarshaler: serializer} + } +} + +func WithMarshaler(marshaler ndef.Marshaler) Option { + return func(opt *Options) { + opt.Marshaler = marshaler + } +} + +func WithUnmarshaler(unmarshaler ndef.Unmarshaler) Option { + return func(opt *Options) { + opt.Unmarshaler = unmarshaler + } +} + +func WithStartupDecidedFunc(startupDecidedFunc LineFunc) Option { + return func(opt *Options) { + opt.StartupDecidedFunc = startupDecidedFunc + } +} + +func WithEndLineDecidedFunc(endLineDecidedFunc LineFunc) Option { + return func(opt *Options) { + opt.EndLineDecidedFunc = endLineDecidedFunc + } +} + +func WithReadIDFunc(readIDFunc ReadIDFunc) Option { + return func(opt *Options) { + opt.ReadIDFunc = readIDFunc + } +} + +func WithTimeout(timeout time.Duration) Option { + return func(opt *Options) { + opt.Timeout = timeout + } +} diff --git a/nsys/cmdn/proc.go b/nsys/cmdn/proc.go new file mode 100644 index 0000000..dd5b9e6 --- /dev/null +++ b/nsys/cmdn/proc.go @@ -0,0 +1,364 @@ +package cmdn + +import ( + "bufio" + "context" + "errors" + "fmt" + "git.noahlan.cn/noahlan/ntool/ndef" + "git.noahlan.cn/noahlan/ntool/nstr" + atomic2 "git.noahlan.cn/noahlan/ntool/nsys/atomic" + "io" + "log" + "os/exec" + "strings" + "sync" + "sync/atomic" + "time" +) + +const ( + DefaultBlockPrefix = "block" + DefaultNonBlockPrefix = "non-block" + DefaultTimeout = time.Second * 0 +) + +var ( + ErrBrokenPipe = errors.New("broken low-level pipe") + + defaultStartupDecidedFunc = func(_ *strings.Builder, line string) bool { + return true + } + defaultEndLineDecidedFunc = func(sb *strings.Builder, line string) bool { + return true + } + + defaultReadIDFunc = func(serializer ndef.Serializer, data string) (string, error) { + return "", nil + } +) + +type ( + pendingMsg struct { + id string + chWait chan struct{} + callback CallbackFunc + } +) + +type ( + LineFunc func(sb *strings.Builder, line string) bool + ReadIDFunc func(serializer ndef.Serializer, data string) (string, error) + CallbackFunc func(serializer ndef.Serializer, data string) + + Options struct { + Marshaler ndef.Marshaler // 序列化 + Unmarshaler ndef.Unmarshaler // 反序列化 + StartupDecidedFunc LineFunc // 启动决议方法 + EndLineDecidedFunc LineFunc // 行尾判断方法 + ReadIDFunc ReadIDFunc // 从数据中获取ID的方法 + Timeout time.Duration // 超时时间 + } + Option func(opt *Options) +) + +// Processor 处理器 +type Processor struct { + *Options + Context context.Context + cancelFunc context.CancelFunc + + Cmd *exec.Cmd // CMD + stdIn io.WriteCloser // 标准输入通道 + stdOut io.ReadCloser // 标准输出通道 + stdErr io.ReadCloser // 标准错误输出通道,一些程序会在此通道输出错误信息和一般信息 + + isBlock bool // 底层是否为同步逻辑 + isStartup *atomic.Bool // 子进程是否真正启动完毕 + chStart chan struct{} // 程序真正启动完毕信号 + + chSend chan ICommand // 待发送数据 + chWrite chan []byte // 实际发送的数据 + + sendIdx *atomic2.AtomicInt64 // 发送缓冲ID + recIdx *atomic2.AtomicInt64 // 接收缓冲ID + pendingMsgMap map[string]*pendingMsg // 发送缓存区map + + *sync.Mutex +} + +func NewProcessor(block bool, opts ...Option) *Processor { + defaultSerializer := NewPlainSerializer() + tmp := &Processor{ + Options: &Options{ + Marshaler: defaultSerializer, + Unmarshaler: defaultSerializer, + StartupDecidedFunc: defaultStartupDecidedFunc, + EndLineDecidedFunc: defaultEndLineDecidedFunc, + ReadIDFunc: defaultReadIDFunc, + Timeout: DefaultTimeout, + }, + + isBlock: block, + isStartup: &atomic.Bool{}, + chStart: make(chan struct{}), + + chSend: make(chan ICommand, 64), + chWrite: make(chan []byte, 64*8), + sendIdx: atomic2.NewAtomicInt64(), + recIdx: atomic2.NewAtomicInt64(), + pendingMsgMap: make(map[string]*pendingMsg), + Mutex: &sync.Mutex{}, + } + + for _, opt := range opts { + opt(tmp.Options) + } + + if tmp.Timeout == 0 { + tmp.Context, tmp.cancelFunc = context.WithCancel(context.Background()) + } else { + tmp.Context, tmp.cancelFunc = context.WithTimeout(context.Background(), tmp.Timeout) + } + return tmp +} + +func (s *Processor) Run(name string, args ...string) error { + s.Cmd = exec.CommandContext(s.Context, name, args...) + s.stdIn, _ = s.Cmd.StdinPipe() + s.stdOut, _ = s.Cmd.StdoutPipe() + s.stdErr, _ = s.Cmd.StderrPipe() + + setProcessGroupID(s.Cmd) + + err := s.Cmd.Start() + + go func() { + err := s.Cmd.Wait() + if err != nil { + log.Println(fmt.Sprintf("错误:命令行 %+v", err)) + } + _ = s.stdErr.Close() + _ = s.stdIn.Close() + _ = s.stdOut.Close() + }() + s.listen() + + return err +} + +// Listen 开始监听 +func (s *Processor) listen() { + go s.handle(s.stdOut, "stdOut") + go s.handle(s.stdErr, "stdErr") + + // 等待程序启动完毕 + select { + case <-s.chStart: + } + + go s.writeLoop() +} + +func (s *Processor) Stop() error { + s.Lock() + defer s.Unlock() + + s.cancelFunc() + return terminateProcess(s.Cmd.Process.Pid) +} + +// Exec 异步执行命令 +func (s *Processor) Exec(data ICommand, callback CallbackFunc) (err error) { + _, err = s.exec(data, false, callback) + return +} + +// ExecAsync 同步执行命令 +func (s *Processor) ExecAsync(data ICommand, callback CallbackFunc) error { + pm, err := s.exec(data, true, callback) + if err != nil { + return err + } + + // 同步等待消息回复 + <-pm.chWait + + return err +} + +func (s *Processor) exec(data ICommand, withWait bool, callback CallbackFunc) (pm *pendingMsg, err error) { + defer func() { + if e := recover(); e != nil { + err = ErrBrokenPipe + _ = s.sendIdx.DecrementAndGet() + } + }() + + cID := data.MessageID() + if len(cID) == 0 { + if s.isBlock { + // block-1 + cID = fmt.Sprintf("%s-%d", DefaultBlockPrefix, s.sendIdx.IncrementAndGet()) + } else { + cID = fmt.Sprintf("%s-%s-%d", DefaultNonBlockPrefix, cID, s.sendIdx.IncrementAndGet()) + // error + return nil, errors.New("异步底层必须消息必须传递消息ID") + } + } + + var chWait chan struct{} + if withWait { + chWait = make(chan struct{}) + } + pm = &pendingMsg{ + id: cID, + chWait: chWait, + callback: callback, + } + + s.Lock() + s.pendingMsgMap[pm.id] = pm + s.Unlock() + + s.chSend <- data + return pm, err +} + +func (s *Processor) writeLoop() { + defer func() { + close(s.chSend) + close(s.chWrite) + }() + + for { + select { + case <-s.Context.Done(): + return + case data := <-s.chSend: + var ( + bytes []byte + err error + ) + bytes, err = s.Marshaler.Marshal(data) + if err != nil { + fmt.Println(fmt.Sprintf("序列化失败: %+v", err)) + break + } + s.chWrite <- bytes + case data := <-s.chWrite: + // 实际写入数据 + fmt.Println(fmt.Sprintf("发送数据: [%s]", string(data))) + + data = append(data, '\n') + if _, err := s.stdIn.Write(data); err != nil { + return + } + } + } +} + +func (s *Processor) handle(reader io.ReadCloser, typ string) { + defer func() { + _ = reader.Close() + }() + + buffer := bufio.NewReader(reader) + endLine := false + content := strings.Builder{} + for { + select { + case <-s.Context.Done(): + return + default: + break + } + lineBytes, isPrefix, err := buffer.ReadLine() + if err != nil || err == io.EOF { + fmt.Println(fmt.Sprintf("[%s] 读取数据时发生错误: %v", typ, err)) + break + } + + line, err := s.readBytesString(lineBytes) + if err != nil { + break + } + + if !s.Started() { + fmt.Println(fmt.Sprintf("[%s] 接收普通消息:[%s]", typ, line)) + + // 判断程序成功启动 外部逻辑 + if s.StartupDecidedFunc(&content, line) { + s.storeStarted(true) + s.chStart <- struct{}{} + fmt.Println(fmt.Sprintf("[%s] 启动完毕,等待输出...", typ)) + } + continue + } + + content.WriteString(line) + if !isPrefix && len(line) > 0 { + content.WriteByte('\n') + } + // 最后一行的判定逻辑 + if !endLine && s.EndLineDecidedFunc(&content, line) { + endLine = true + } + + if endLine { + endLine = false + cStr := content.String() + if len(cStr) > 0 { + cStr = cStr[:len(cStr)-1] + } + content.Reset() + + revID, err := s.ReadIDFunc(ndef.NewSerializerWrapper(s.Marshaler, s.Unmarshaler), cStr) + if err != nil { + continue + } + //fmt.Println(fmt.Sprintf("[%s] 接收指令消息:[%s]", typ, cStr)) + + // block 需要自行维护ID,是无法透传的 + if len(revID) <= 0 { + revID = fmt.Sprintf("%s-%d", DefaultBlockPrefix, s.recIdx.IncrementAndGet()) + } + + pending, ok := s.pendingMsgMap[revID] + if !ok { + fmt.Println("找不到已发送数据,无法对应处理数据!") + continue + } + delete(s.pendingMsgMap, revID) + + go func() { + pending.callback(ndef.NewSerializerWrapper(s.Marshaler, s.Unmarshaler), cStr) + if pending.chWait != nil { + pending.chWait <- struct{}{} + } + }() + } + } +} + +func (s *Processor) readBytesString(data []byte) (string, error) { + // 编码 + if nstr.Charset(data) == nstr.GBK { + gbk, err := nstr.ToGBK(data) + if err != nil { + return "", err + } + return string(gbk), nil + } else { + return string(data), nil + } +} + +// Started 是否启动完成 +func (s *Processor) Started() bool { + return s.isStartup.Load() +} + +func (s *Processor) storeStarted(val bool) { + s.isStartup.Store(val) +} diff --git a/nsys/cmdn/response.go b/nsys/cmdn/response.go new file mode 100644 index 0000000..366c74b --- /dev/null +++ b/nsys/cmdn/response.go @@ -0,0 +1,40 @@ +package cmdn + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +type IResponse interface { + // MessageID 消息ID + MessageID() string +} + +type PlainResp struct { + ID string + Command string + Result string + Err error +} + +func (r *PlainResp) MessageID() string { + return r.ID +} + +func (r *PlainResp) GetResult() (string, error) { + reg, _ := regexp.Compile(`\t`) + result := strings.TrimSpace(reg.ReplaceAllString(r.Result, " ")) + + res := strings.Fields(result) + l := len(res) + if l > 0 { + if res[0] == "=" { + return strings.TrimSpace(strings.Join(res[1:], " ")), nil + } else if res[0] == "?" { + return "", errors.New(strings.Join(res[1:], " ")) + } + } + return "", errors.New(fmt.Sprintf("错误(未知应答): %s", r.Err)) +} diff --git a/nsys/cmdn/serializer_plain.go b/nsys/cmdn/serializer_plain.go new file mode 100644 index 0000000..a606e65 --- /dev/null +++ b/nsys/cmdn/serializer_plain.go @@ -0,0 +1,48 @@ +package cmdn + +import ( + "errors" + "fmt" + "git.noahlan.cn/noahlan/ntool/ndef" + "runtime" + "strings" +) + +type PlainSerializer struct { + sysType string +} + +func NewPlainSerializer() ndef.Serializer { + return &PlainSerializer{ + sysType: runtime.GOOS, + } +} + +func (s *PlainSerializer) Marshal(v interface{}) ([]byte, error) { + ret, ok := v.(*PlainCommand) + if !ok { + return nil, errors.New(fmt.Sprintf("参数类型必须为 %T", PlainCommand{})) + } + // ret arg0 arg1 arg2 ... + // cmd arg0 arg1 arg2 ... + sb := strings.Builder{} + if ret.ID != "" { + sb.WriteString(ret.ID) + sb.WriteString(" ") + } + sb.WriteString(ret.Cmd) + sb.WriteString(" ") + sb.WriteString(strings.Join(ret.Args, " ")) + return []byte(sb.String()), nil +} + +func (s *PlainSerializer) Unmarshal(data []byte, v interface{}) error { + t, ok := v.(*PlainResp) + if !ok { + return errors.New(fmt.Sprintf("参数类型必须为 %T", PlainResp{})) + } + t.ID = "" + //t.Command + t.Result = string(data) + return nil +} diff --git a/nsys/cmdr/cmd.go b/nsys/cmdr/cmd.go new file mode 100644 index 0000000..96281ce --- /dev/null +++ b/nsys/cmdr/cmd.go @@ -0,0 +1,449 @@ +package cmdr + +import ( + "context" + "fmt" + "git.noahlan.cn/noahlan/ntool/narr" + "git.noahlan.cn/noahlan/ntool/ncli/cmdline" + "io" + "os" + "os/exec" + "path/filepath" +) + +// Cmd struct +type Cmd struct { + *exec.Cmd + // Name of the command + Name string + // DryRun if True, not real execute command + DryRun bool + // Vars mapping + Vars map[string]string + + // BeforeRun hook + BeforeRun func(c *Cmd) + // AfterRun hook + AfterRun func(c *Cmd, err error) +} + +// NewGitCmd instance +func NewGitCmd(subCmd string, args ...string) *Cmd { + return NewCmd("git", subCmd).AddArgs(args) +} + +// NewCmdline instance +// +// see exec.Command +func NewCmdline(line string) *Cmd { + bin, args := cmdline.NewParser(line).WithParseEnv().BinAndArgs() + return NewCmd(bin, args...) +} + +// NewCmd instance +// +// see exec.Command +func NewCmd(bin string, args ...string) *Cmd { + return WrapGoCmd(exec.Command(bin, args...)) +} + +// CmdWithCtx create new instance with context. +// +// see exec.CommandContext +func CmdWithCtx(ctx context.Context, bin string, args ...string) *Cmd { + return WrapGoCmd(exec.CommandContext(ctx, bin, args...)) +} + +// WrapGoCmd instance +func WrapGoCmd(cmd *exec.Cmd) *Cmd { + return &Cmd{ + Cmd: cmd, + Vars: make(map[string]string), + } +} + +// ------------------------------------------------- +// config the command +// ------------------------------------------------- + +// Config the command +func (c *Cmd) Config(fn func(c *Cmd)) *Cmd { + fn(c) + return c +} + +// WithDryRun on exec command +func (c *Cmd) WithDryRun(dryRun bool) *Cmd { + c.DryRun = dryRun + return c +} + +// PrintCmdline on exec command +func (c *Cmd) PrintCmdline() *Cmd { + c.BeforeRun = PrintCmdline + return c +} + +// OnBefore exec add hook +func (c *Cmd) OnBefore(fn func(c *Cmd)) *Cmd { + c.BeforeRun = fn + return c +} + +// OnAfter exec add hook +func (c *Cmd) OnAfter(fn func(c *Cmd, err error)) *Cmd { + c.AfterRun = fn + return c +} + +// WithBin name returns the current object +func (c *Cmd) WithBin(name string) *Cmd { + c.Args[0] = name + c.lookPath(name) + return c +} + +func (c *Cmd) lookPath(name string) { + if filepath.Base(name) == name { + lp, err := exec.LookPath(name) + if lp != "" { + // Update cmd.Path even if err is non-nil. + // If err is ErrDot (especially on Windows), lp may include a resolved + // extension (like .exe or .bat) that should be preserved. + c.Path = lp + } + if err != nil { + panic(fmt.Sprintf("cmdr: look %q path error: %v", name, err)) + } + } +} + +// WithGoCmd and returns the current instance. +func (c *Cmd) WithGoCmd(ec *exec.Cmd) *Cmd { + c.Cmd = ec + return c +} + +// WithWorkDir returns the current object +func (c *Cmd) WithWorkDir(dir string) *Cmd { + c.Dir = dir + return c +} + +// WorkDirOnNE set workdir on input is not empty +func (c *Cmd) WorkDirOnNE(dir string) *Cmd { + if dir == "" { + c.Dir = dir + } + return c +} + +// WithEnvMap override set new ENV for run +func (c *Cmd) WithEnvMap(mp map[string]string) *Cmd { + if ln := len(mp); ln > 0 { + c.Env = make([]string, 0, ln) + for key, val := range mp { + c.Env = append(c.Env, key+"="+val) + } + } + return c +} + +// AppendEnv to the os ENV for run command +func (c *Cmd) AppendEnv(mp map[string]string) *Cmd { + if len(mp) > 0 { + // init env data + if c.Env == nil { + c.Env = os.Environ() + } + + for name, val := range mp { + c.Env = append(c.Env, name+"="+val) + } + } + + return c +} + +// OutputToOS output to OS stdout and error +func (c *Cmd) OutputToOS() *Cmd { + return c.ToOSStdoutStderr() +} + +// ToOSStdoutStderr output to OS stdout and error +func (c *Cmd) ToOSStdoutStderr() *Cmd { + c.Stdout = os.Stdout + c.Stderr = os.Stderr + return c +} + +// ToOSStdout output to OS stdout +func (c *Cmd) ToOSStdout() *Cmd { + c.Stdout = os.Stdout + c.Stderr = os.Stdout + return c +} + +// WithStdin returns the current argument +func (c *Cmd) WithStdin(in io.Reader) *Cmd { + c.Stdin = in + return c +} + +// WithOutput returns the current instance +func (c *Cmd) WithOutput(out, errOut io.Writer) *Cmd { + c.Stdout = out + if errOut != nil { + c.Stderr = errOut + } + return c +} + +// WithAnyArgs add args and returns the current object. +func (c *Cmd) WithAnyArgs(args ...any) *Cmd { + c.Args = append(c.Args, narr.SliceToStrings(args)...) + return c +} + +// AddArg add args and returns the current object +func (c *Cmd) AddArg(args ...string) *Cmd { return c.WithArg(args...) } + +// WithArg add args and returns the current object. alias of the WithArg() +func (c *Cmd) WithArg(args ...string) *Cmd { + c.Args = append(c.Args, args...) + return c +} + +// AddArgf add args and returns the current object. alias of the WithArgf() +func (c *Cmd) AddArgf(format string, args ...any) *Cmd { + return c.WithArgf(format, args...) +} + +// WithArgf add arg and returns the current object +func (c *Cmd) WithArgf(format string, args ...any) *Cmd { + c.Args = append(c.Args, fmt.Sprintf(format, args...)) + return c +} + +// ArgIf add arg and returns the current object +func (c *Cmd) ArgIf(arg string, exprOk bool) *Cmd { + if exprOk { + c.Args = append(c.Args, arg) + } + return c +} + +// WithArgIf add arg and returns the current object +func (c *Cmd) WithArgIf(arg string, exprOk bool) *Cmd { + return c.ArgIf(arg, exprOk) +} + +// AddArgs for the git. alias of WithArgs() +func (c *Cmd) AddArgs(args []string) *Cmd { return c.WithArgs(args) } + +// WithArgs for the git +func (c *Cmd) WithArgs(args []string) *Cmd { + if len(args) > 0 { + c.Args = append(c.Args, args...) + } + return c +} + +// WithArgsIf add arg and returns the current object +func (c *Cmd) WithArgsIf(args []string, exprOk bool) *Cmd { + if exprOk && len(args) > 0 { + c.Args = append(c.Args, args...) + } + return c +} + +// WithVars add vars and returns the current object +func (c *Cmd) WithVars(vs map[string]string) *Cmd { + if len(vs) > 0 { + c.Vars = vs + } + return c +} + +// SetVar add var and returns the current object +func (c *Cmd) SetVar(name, val string) *Cmd { + c.Vars[name] = val + return c +} + +// ------------------------------------------------- +// helper command +// ------------------------------------------------- + +// IDString of the command +func (c *Cmd) IDString() string { + if c.Name != "" { + return c.Name + } + return c.BinOrPath() +} + +// BinName of the command +func (c *Cmd) BinName() string { + if len(c.Args) > 0 { + return c.Args[0] + } + return "" +} + +// BinOrPath of the command +func (c *Cmd) BinOrPath() string { + if len(c.Args) > 0 { + return c.Args[0] + } + return c.Path +} + +// OnlyArgs of the command, not contains bin name. +func (c *Cmd) OnlyArgs() (ss []string) { + if len(c.Args) > 1 { + return c.Args[1:] + } + return +} + +// ResetArgs for command, but will keep bin name. +func (c *Cmd) ResetArgs() { + if len(c.Args) > 0 { + c.Args = c.Args[0:1] + } else { + c.Args = c.Args[:0] + } +} + +// Workdir of the command +func (c *Cmd) Workdir() string { + return c.Dir +} + +// Cmdline to command line +func (c *Cmd) Cmdline() string { + return cmdline.Cmdline(c.Args) +} + +// Copy new instance from current command, with new args. +func (c *Cmd) Copy(args ...string) *Cmd { + nc := *c + + // copy bin name. + if len(c.Args) > 0 { + nc.Args = append([]string{c.Args[0]}, args...) + } else { + nc.Args = args + } + + return &nc +} + +// GoCmd get exec.Cmd +func (c *Cmd) GoCmd() *exec.Cmd { return c.Cmd } + +// ------------------------------------------------- +// run command +// ------------------------------------------------- + +// Success run and return whether success +func (c *Cmd) Success() bool { + return c.Run() == nil +} + +// HasStdout output setting. +func (c *Cmd) HasStdout() bool { + return c.Stdout != nil +} + +// SafeLines run and return output as lines +func (c *Cmd) SafeLines() []string { + ss, _ := c.OutputLines() + return ss +} + +// OutputLines run and return output as lines +func (c *Cmd) OutputLines() ([]string, error) { + out, err := c.Output() + if err != nil { + return nil, err + } + return OutputLines(out), err +} + +// SafeOutput run and return output +func (c *Cmd) SafeOutput() string { + out, err := c.Output() + if err != nil { + return "" + } + return out +} + +// Output run and return output +func (c *Cmd) Output() (string, error) { + if c.BeforeRun != nil { + c.BeforeRun(c) + } + + if c.DryRun { + return "DRY-RUN: ok", nil + } + + output, err := c.Cmd.Output() + + if c.AfterRun != nil { + c.AfterRun(c, err) + } + return string(output), err +} + +// CombinedOutput run and return output, will combine stderr and stdout output +func (c *Cmd) CombinedOutput() (string, error) { + if c.BeforeRun != nil { + c.BeforeRun(c) + } + + if c.DryRun { + return "DRY-RUN: ok", nil + } + + output, err := c.Cmd.CombinedOutput() + + if c.AfterRun != nil { + c.AfterRun(c, err) + } + return string(output), err +} + +// MustRun a command. will panic on error +func (c *Cmd) MustRun() { + if err := c.Run(); err != nil { + panic(err) + } +} + +// FlushRun runs command and flush output to stdout +func (c *Cmd) FlushRun() error { + return c.ToOSStdoutStderr().Run() +} + +// Run runs command +func (c *Cmd) Run() error { + if c.BeforeRun != nil { + c.BeforeRun(c) + } + + if c.DryRun { + return nil + } + + // do running + err := c.Cmd.Run() + + if c.AfterRun != nil { + c.AfterRun(c, err) + } + return err +} diff --git a/nsys/cmdr/cmd_test.go b/nsys/cmdr/cmd_test.go new file mode 100644 index 0000000..ce3b523 --- /dev/null +++ b/nsys/cmdr/cmd_test.go @@ -0,0 +1,35 @@ +package cmdr_test + +import ( + "fmt" + "git.noahlan.cn/noahlan/ntool/nsys/cmdr" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestNewCmd(t *testing.T) { + c := cmdr.NewCmd("ls"). + WithArg("-l"). + WithArgs([]string{"-h"}). + AddArg("-a"). + AddArgf("%s", "./") + + assert.Eq(t, "ls", c.BinName()) + assert.Eq(t, "ls", c.IDString()) + assert.StrContains(t, "ls", c.BinOrPath()) + assert.NotContains(t, c.OnlyArgs(), "ls") + + c.OnBefore(func(c *cmdr.Cmd) { + assert.Eq(t, "ls -l -h -a ./", c.Cmdline()) + }) + + out := c.SafeOutput() + fmt.Println(out) + assert.NotEmpty(t, out) + assert.NotEmpty(t, cmdr.OutputLines(out)) + assert.NotEmpty(t, cmdr.FirstLine(out)) + + c.ResetArgs() + assert.Len(t, c.Args, 1) + assert.Empty(t, c.OnlyArgs()) +} diff --git a/nsys/cmdr/cmdr.go b/nsys/cmdr/cmdr.go new file mode 100644 index 0000000..5d1253a --- /dev/null +++ b/nsys/cmdr/cmdr.go @@ -0,0 +1,32 @@ +package cmdr + +import ( + "github.com/gookit/color" + "strings" +) + +// PrintCmdline on before exec +func PrintCmdline(c *Cmd) { + if c.DryRun { + color.Yellowln("DRY-RUN>", c.Cmdline()) + } else { + color.Yellowln(">", c.Cmdline()) + } +} + +// OutputLines split output to lines +func OutputLines(output string) []string { + output = strings.TrimSuffix(output, "\n") + if output == "" { + return nil + } + return strings.Split(output, "\n") +} + +// FirstLine from command output +func FirstLine(output string) string { + if i := strings.Index(output, "\n"); i >= 0 { + return output[0:i] + } + return output +} diff --git a/nsys/cmdr/runner.go b/nsys/cmdr/runner.go new file mode 100644 index 0000000..14639f3 --- /dev/null +++ b/nsys/cmdr/runner.go @@ -0,0 +1,316 @@ +package cmdr + +import ( + "fmt" + "git.noahlan.cn/noahlan/ntool/narr" + "git.noahlan.cn/noahlan/ntool/ncli/cmdline" + "git.noahlan.cn/noahlan/ntool/nmap" + "git.noahlan.cn/noahlan/ntool/nmath" + "git.noahlan.cn/noahlan/ntool/nstr/textutil" + "github.com/gookit/color" + "strings" +) + +// Task struct +type Task struct { + err error + index int + + // ID for task + ID string + Cmd *Cmd + + // BeforeRun hook + BeforeRun func(t *Task) + PrevCond func(prev *Task) bool +} + +// NewTask instance +func NewTask(cmd *Cmd) *Task { + return &Task{ + Cmd: cmd, + } +} + +// get task id by cmd.Name +func (t *Task) ensureID(idx int) { + t.index = idx + if t.ID != "" { + return + } + + id := t.Cmd.IDString() + if t.Cmd.Name == "" { + id += nmath.String(idx) + } + t.ID = id +} + +var rpl = textutil.NewVarReplacer("$").DisableFlatten() + +// RunWith command +func (t *Task) RunWith(ctx nmap.Data) error { + cmdVars := ctx.StringMap("cmdVars") + + if len(cmdVars) > 0 { + // rpl := strutil.NewReplacer(cmdVars) + for i, val := range t.Cmd.Args { + if strings.ContainsRune(val, '$') { + t.Cmd.Args[i] = rpl.RenderSimple(val, cmdVars) + } + } + } + + return t.Run() +} + +// Run command +func (t *Task) Run() error { + if t.BeforeRun != nil { + t.BeforeRun(t) + } + + t.err = t.Cmd.Run() + return t.err +} + +// Err get +func (t *Task) Err() error { + return t.err +} + +// Index get +func (t *Task) Index() int { + return t.index +} + +// Cmdline get +func (t *Task) Cmdline() string { + return t.Cmd.Cmdline() +} + +// IsSuccess of task +func (t *Task) IsSuccess() bool { + return t.err == nil +} + +// RunnerHookFn func +type RunnerHookFn func(r *Runner, t *Task) bool + +// Runner use for batch run multi task commands +type Runner struct { + prev *Task + // task name to index + idMap map[string]int + tasks []*Task + // Errs on run tasks, key is Task.ID + Errs nmap.ErrMap + + // TODO Concurrent run + + // Workdir common workdir + Workdir string + // EnvMap will append to task.Cmd on run + EnvMap map[string]string + + // Params for add custom params + Params nmap.Map + + // DryRun dry run all commands + DryRun bool + // OutToStd stdout and stderr + OutToStd bool + // IgnoreErr continue on error + IgnoreErr bool + // BeforeRun hooks on each task. return false to skip current task. + BeforeRun func(r *Runner, t *Task) bool + // AfterRun hook on each task. return false to stop running. + AfterRun func(r *Runner, t *Task) bool +} + +// NewRunner instance with config func +func NewRunner(fns ...func(rr *Runner)) *Runner { + rr := &Runner{ + idMap: make(map[string]int, 0), + tasks: make([]*Task, 0), + Errs: make(nmap.ErrMap), + Params: make(nmap.Map), + } + + rr.OutToStd = true + for _, fn := range fns { + fn(rr) + } + return rr +} + +// WithOutToStd set +func (r *Runner) WithOutToStd() *Runner { + r.OutToStd = true + return r +} + +// Add multitask at once +func (r *Runner) Add(tasks ...*Task) *Runner { + for _, task := range tasks { + r.AddTask(task) + } + return r +} + +// AddTask add one task +func (r *Runner) AddTask(task *Task) *Runner { + if task.Cmd == nil { + panic("task command cannot be empty") + } + + idx := len(r.tasks) + task.ensureID(idx) + + // TODO check id repeat + r.idMap[task.ID] = idx + r.tasks = append(r.tasks, task) + return r +} + +// AddCmd commands +func (r *Runner) AddCmd(cmds ...*Cmd) *Runner { + for _, cmd := range cmds { + r.AddTask(&Task{Cmd: cmd}) + } + return r +} + +// GitCmd quick a git command task +func (r *Runner) GitCmd(subCmd string, args ...string) *Runner { + return r.AddTask(&Task{ + Cmd: NewGitCmd(subCmd, args...), + }) +} + +// CmdWithArgs a command task +func (r *Runner) CmdWithArgs(cmdName string, args ...string) *Runner { + return r.AddTask(&Task{ + Cmd: NewCmd(cmdName, args...), + }) +} + +// CmdWithAnys a command task +func (r *Runner) CmdWithAnys(cmdName string, args ...any) *Runner { + return r.AddTask(&Task{ + Cmd: NewCmd(cmdName, narr.SliceToStrings(args)...), + }) +} + +// AddCmdline as a command task +func (r *Runner) AddCmdline(line string) *Runner { + bin, args := cmdline.NewParser(line).BinAndArgs() + + return r.AddTask(&Task{ + Cmd: NewCmd(bin, args...), + }) +} + +// Run all tasks +func (r *Runner) Run() error { + // do run tasks + for i, task := range r.tasks { + if r.BeforeRun != nil && !r.BeforeRun(r, task) { + continue + } + + if r.prev != nil && task.PrevCond != nil && !task.PrevCond(r.prev) { + continue + } + + if r.DryRun { + color.Infof("DRY-RUN: task#%d execute completed\n\n", i+1) + continue + } + + if !r.RunTask(task) { + break + } + fmt.Println() // with newline. + } + + if len(r.Errs) == 0 { + return nil + } + return r.Errs +} + +// StepRun one command +func (r *Runner) StepRun() error { + return nil // TODO +} + +// RunTask command +func (r *Runner) RunTask(task *Task) (goon bool) { + if len(r.EnvMap) > 0 { + task.Cmd.AppendEnv(r.EnvMap) + } + + if r.OutToStd && !task.Cmd.HasStdout() { + task.Cmd.ToOSStdoutStderr() + } + + // common workdir + if r.Workdir != "" && task.Cmd.Dir == "" { + task.Cmd.WithWorkDir(r.Workdir) + } + + // do running + if err := task.RunWith(r.Params); err != nil { + r.Errs[task.ID] = err + color.Errorf("Task#%d run error: %s\n", task.Index()+1, err) + + // not ignore error, stop. + if !r.IgnoreErr { + return false + } + } + + if r.AfterRun != nil && !r.AfterRun(r, task) { + return false + } + + // store prev + r.prev = task + return true +} + +// Len of tasks +func (r *Runner) Len() int { + return len(r.tasks) +} + +// Reset instance +func (r *Runner) Reset() *Runner { + r.prev = nil + r.tasks = make([]*Task, 0) + r.idMap = make(map[string]int, 0) + return r +} + +// TaskIDs get +func (r *Runner) TaskIDs() []string { + ss := make([]string, 0, len(r.idMap)) + for id := range r.idMap { + ss = append(ss, id) + } + return ss +} + +// Prev task instance after running +func (r *Runner) Prev() *Task { + return r.prev +} + +// Task get by id name +func (r *Runner) Task(id string) (*Task, error) { + if idx, ok := r.idMap[id]; ok { + return r.tasks[idx], nil + } + return nil, fmt.Errorf("task %q is not exists", id) +} diff --git a/nsys/cmdr/runner_test.go b/nsys/cmdr/runner_test.go new file mode 100644 index 0000000..f4efb83 --- /dev/null +++ b/nsys/cmdr/runner_test.go @@ -0,0 +1,41 @@ +package cmdr_test + +import ( + "bytes" + "fmt" + "git.noahlan.cn/noahlan/ntool/nsys/cmdr" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestRunner_Run(t *testing.T) { + buf := new(bytes.Buffer) + rr := cmdr.NewRunner() + + rr.Add(&cmdr.Task{ + ID: "task1", + Cmd: cmdr.NewCmd("id").WithOutput(buf, buf), + }) + rr.AddCmd(cmdr.NewCmd("ls"). + AddArgs([]string{"-l", "-h"}). + WithOutput(buf, buf). + OnBefore(cmdr.PrintCmdline)) + + task, err := rr.Task("task1") + assert.NoErr(t, err) + assert.NoErr(t, task.Err()) + assert.True(t, task.IsSuccess()) + + ids := rr.TaskIDs() + // dump.P(rr.TaskIDs()) + assert.Len(t, ids, 2) + assert.NotEmpty(t, ids) + assert.Contains(t, ids, task.ID) + + err = rr.Run() + assert.NoErr(t, err) + assert.NoErr(t, rr.Errs.One()) + assert.True(t, rr.Errs.IsEmpty()) + + fmt.Println(buf.String()) +} diff --git a/nsys/exec.go b/nsys/exec.go new file mode 100644 index 0000000..5781c0c --- /dev/null +++ b/nsys/exec.go @@ -0,0 +1,71 @@ +package nsys + +import ( + "bytes" + "git.noahlan.cn/noahlan/ntool/ncli/cmdline" + "git.noahlan.cn/noahlan/ntool/nsys/cmdr" + "os/exec" +) + +// NewCmd instance +func NewCmd(bin string, args ...string) *cmdr.Cmd { + return cmdr.NewCmd(bin, args...) +} + +// FlushExec instance +func FlushExec(bin string, args ...string) error { + return cmdr.NewCmd(bin, args...).FlushRun() +} + +// QuickExec quick exec an simple command line +func QuickExec(cmdLine string, workDir ...string) (string, error) { + return ExecLine(cmdLine, workDir...) +} + +// ExecLine quick exec an command line string +func ExecLine(cmdLine string, workDir ...string) (string, error) { + p := cmdline.NewParser(cmdLine) + + // create a new Cmd instance + cmd := p.NewExecCmd() + if len(workDir) > 0 { + cmd.Dir = workDir[0] + } + + bs, err := cmd.Output() + return string(bs), err +} + +// ExecCmd a command and return output. +// +// Usage: +// +// ExecCmd("ls", []string{"-al"}) +func ExecCmd(binName string, args []string, workDir ...string) (string, error) { + // create a new Cmd instance + cmd := exec.Command(binName, args...) + if len(workDir) > 0 { + cmd.Dir = workDir[0] + } + + bs, err := cmd.Output() + return string(bs), err +} + +// ShellExec exec command by shell cmdLine. eg: "ls -al" +func ShellExec(cmdLine string, shells ...string) (string, error) { + // shell := "/bin/sh" + shell := "sh" + if len(shells) > 0 { + shell = shells[0] + } + + var out bytes.Buffer + cmd := exec.Command(shell, "-c", cmdLine) + cmd.Stdout = &out + + if err := cmd.Run(); err != nil { + return "", err + } + return out.String(), nil +} diff --git a/nsys/retry/retry.go b/nsys/retry/retry.go new file mode 100644 index 0000000..b620557 --- /dev/null +++ b/nsys/retry/retry.go @@ -0,0 +1,87 @@ +package retry + +import ( + "context" + "errors" + "fmt" + "reflect" + "runtime" + "strings" + "time" +) + +const ( + // DefaultRetryTimes times of retry + DefaultRetryTimes = 3 + // DefaultRetryDuration time duration of two retries + DefaultRetryDuration = time.Second * 1 +) + +// Config is config for retry +type Config struct { + context context.Context + retryTimes uint + retryDuration time.Duration +} + +// Func is function that retry executes +type Func func() error + +// Option is for adding retry config +type Option func(*Config) + +// WithTimes set times of retry. +func WithTimes(n uint) Option { + return func(rc *Config) { + rc.retryTimes = n + } +} + +// WithDuration set duration of retries. +func WithDuration(d time.Duration) Option { + return func(rc *Config) { + rc.retryDuration = d + } +} + +// WithContext set retry context config. +func WithContext(ctx context.Context) Option { + return func(rc *Config) { + rc.context = ctx + } +} + +// Retry executes the retryFunc repeatedly until it was successful or canceled by the context +// The default times of retries is 3 and the default duration between retries is 1 seconds. +func Retry(retryFunc Func, opts ...Option) error { + config := &Config{ + retryTimes: DefaultRetryTimes, + retryDuration: DefaultRetryDuration, + context: context.TODO(), + } + + for _, opt := range opts { + opt(config) + } + + var i uint + for i < config.retryTimes { + err := retryFunc() + if err != nil { + select { + case <-time.After(config.retryDuration): + case <-config.context.Done(): + return errors.New("retry is cancelled") + } + } else { + return nil + } + i++ + } + + funcPath := runtime.FuncForPC(reflect.ValueOf(retryFunc).Pointer()).Name() + lastSlash := strings.LastIndex(funcPath, "/") + funcName := funcPath[lastSlash+1:] + + return fmt.Errorf("function %s run failed after %d times retry", funcName, i) +} diff --git a/nsys/retry/retry_test.go b/nsys/retry/retry_test.go new file mode 100644 index 0000000..592a75c --- /dev/null +++ b/nsys/retry/retry_test.go @@ -0,0 +1,73 @@ +package retry + +import ( + "context" + "errors" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" + "time" +) + +func TestRetryFailed(t *testing.T) { + + var number int + increaseNumber := func() error { + number++ + return errors.New("error occurs") + } + + err := Retry(increaseNumber, WithDuration(time.Microsecond*50)) + assert.NotNil(t, err) + assert.Equal(t, DefaultRetryTimes, number) +} + +func TestRetrySucceeded(t *testing.T) { + var number int + increaseNumber := func() error { + number++ + if number == DefaultRetryTimes { + return nil + } + return errors.New("error occurs") + } + + err := Retry(increaseNumber, WithDuration(time.Microsecond*50)) + + assert.Nil(t, err) + assert.Equal(t, DefaultRetryTimes, number) +} + +func TestSetRetryTimes(t *testing.T) { + + var number int + increaseNumber := func() error { + number++ + return errors.New("error occurs") + } + + err := Retry(increaseNumber, WithDuration(time.Microsecond*50), WithTimes(3)) + + assert.NotNil(t, err) + assert.Equal(t, 3, number) +} + +func TestCancelRetry(t *testing.T) { + ctx, cancel := context.WithCancel(context.TODO()) + var number int + increaseNumber := func() error { + number++ + if number > 3 { + cancel() + } + return errors.New("error occurs") + } + + err := Retry(increaseNumber, + WithDuration(time.Microsecond*50), + WithContext(ctx), + WithTimes(100), + ) + + assert.NotNil(t, err) + assert.Equal(t, 4, number) +} diff --git a/nsys/sysenv.go b/nsys/sysenv.go new file mode 100644 index 0000000..a0ce031 --- /dev/null +++ b/nsys/sysenv.go @@ -0,0 +1,184 @@ +package nsys + +import ( + "git.noahlan.cn/noahlan/ntool/internal/common" + "git.noahlan.cn/noahlan/ntool/ncli" + "golang.org/x/term" + "io" + "os" + "os/exec" + "path/filepath" + "strings" + "syscall" +) + +// IsMSys msys(MINGW64) env,不一定支持颜色 +func IsMSys() bool { + // "MSYSTEM=MINGW64" + return len(os.Getenv("MSYSTEM")) > 0 +} + +// IsConsole check out is in stderr/stdout/stdin +// +// Usage: +// +// sysutil.IsConsole(os.Stdout) +func IsConsole(out io.Writer) bool { + o, ok := out.(*os.File) + if !ok { + return false + } + + fd := o.Fd() + + // fix: cannot use 'o == os.Stdout' to compare + return fd == uintptr(syscall.Stdout) || fd == uintptr(syscall.Stdin) || fd == uintptr(syscall.Stderr) +} + +// IsTerminal isatty check +// +// Usage: +// +// sysutil.IsTerminal(os.Stdout.Fd()) +func IsTerminal(fd uintptr) bool { + // return isatty.IsTerminal(fd) // "github.com/mattn/go-isatty" + return term.IsTerminal(int(fd)) +} + +// StdIsTerminal os.Stdout is terminal +func StdIsTerminal() bool { + return IsTerminal(os.Stdout.Fd()) +} + +// Hostname is alias of os.Hostname, but ignore error +func Hostname() string { + name, _ := os.Hostname() + return name +} + +// CurrentShell get current used shell env file. +// +// eg "/bin/zsh" "/bin/bash". +// if onlyName=true, will return "zsh", "bash" +func CurrentShell(onlyName bool) (path string) { + return ncli.CurrentShell(onlyName) +} + +// HasShellEnv has shell env check. +// +// Usage: +// +// HasShellEnv("sh") +// HasShellEnv("bash") +func HasShellEnv(shell string) bool { + // can also use: "echo $0" + out, err := ShellExec("echo OK", shell) + if err != nil { + return false + } + + return strings.TrimSpace(out) == "OK" +} + +// IsShellSpecialVar reports whether the character identifies a special +// shell variable such as $*. +func IsShellSpecialVar(c uint8) bool { + switch c { + case '*', '#', '$', '@', '!', '?', '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + return true + } + return false +} + +// FindExecutable in the system +// +// Usage: +// +// sysutil.FindExecutable("bash") +func FindExecutable(binName string) (string, error) { + return exec.LookPath(binName) +} + +// Executable find in the system, alias of FindExecutable() +// +// Usage: +// +// sysutil.Executable("bash") +func Executable(binName string) (string, error) { + return exec.LookPath(binName) +} + +// HasExecutable in the system +// +// Usage: +// +// HasExecutable("bash") +func HasExecutable(binName string) bool { + _, err := exec.LookPath(binName) + return err == nil +} + +// GetEnv get ENV value by key name, can with default value +func GetEnv(name string, def ...string) string { + val := os.Getenv(name) + if val == "" && len(def) > 0 { + val = def[0] + } + return val +} + +// Environ like os.Environ, but will returns key-value map[string]string data. +func Environ() map[string]string { return common.Environ() } + +// EnvMapWith like os.Environ, but will return key-value map[string]string data. +func EnvMapWith(newEnv map[string]string) map[string]string { + envMp := common.Environ() + for name, value := range newEnv { + envMp[name] = value + } + return envMp +} + +// EnvPaths get and split $PATH to []string +func EnvPaths() []string { + return filepath.SplitList(os.Getenv("PATH")) +} + +// SearchPath search executable files in the system $PATH +// +// Usage: +// +// sysutil.SearchPath("go") +func SearchPath(keywords string, limit int) []string { + path := os.Getenv("PATH") + ptn := "*" + keywords + "*" + list := make([]string, 0) + + checked := make(map[string]bool) + for _, dir := range filepath.SplitList(path) { + // Unix shell semantics: path element "" means "." + if dir == "" { + dir = "." + } + + // mark dir is checked + if _, ok := checked[dir]; ok { + continue + } + + checked[dir] = true + matches, err := filepath.Glob(filepath.Join(dir, ptn)) + if err == nil && len(matches) > 0 { + list = append(list, matches...) + size := len(list) + + // limit result size + if limit > 0 && size >= limit { + list = list[:limit] + break + } + } + } + + return list +} diff --git a/nsys/sysutil.go b/nsys/sysutil.go new file mode 100644 index 0000000..6288714 --- /dev/null +++ b/nsys/sysutil.go @@ -0,0 +1,47 @@ +// Package nsys provide some system util functions. eg: sysenv, exec, user, process +package nsys + +import ( + "git.noahlan.cn/noahlan/ntool/internal/common" + "os" + "path/filepath" +) + +// Workdir get +func Workdir() string { + return common.Workdir() +} + +// BinDir get +func BinDir() string { + return filepath.Dir(os.Args[0]) +} + +// BinName get +func BinName() string { + return filepath.Base(os.Args[0]) +} + +// BinFile get +func BinFile() string { + return os.Args[0] +} + +// Open file or url address +func Open(fileOrURL string) error { + return OpenURL(fileOrURL) +} + +// OpenBrowser file or url address +func OpenBrowser(fileOrURL string) error { + return OpenURL(fileOrURL) +} + +// OpenFile opens new browser window for the file path. +func OpenFile(path string) error { + fpath, err := filepath.Abs(path) + if err != nil { + return err + } + return OpenURL("file://" + fpath) +} diff --git a/nsys/sysutil_darwin.go b/nsys/sysutil_darwin.go new file mode 100644 index 0000000..11a292b --- /dev/null +++ b/nsys/sysutil_darwin.go @@ -0,0 +1,36 @@ +package nsys + +import "os/exec" + +// IsWin system. linux windows darwin +func IsWin() bool { return false } + +// IsWindows system. linux windows darwin +func IsWindows() bool { return false } + +// IsMac system +func IsMac() bool { return true } + +// IsDarwin system +func IsDarwin() bool { return true } + +// IsLinux system +func IsLinux() bool { return false } + +// OpenURL Open browser URL +// +// Mac: +// +// open 'https://github.com/inhere' +// +// Linux: +// +// xdg-open URL +// x-www-browser 'https://github.com/inhere' +// +// Windows: +// +// cmd /c start https://github.com/inhere +func OpenURL(URL string) error { + return exec.Command("open", URL).Run() +} diff --git a/nsys/sysutil_nonwin.go b/nsys/sysutil_nonwin.go new file mode 100644 index 0000000..a4da154 --- /dev/null +++ b/nsys/sysutil_nonwin.go @@ -0,0 +1,17 @@ +//go:build !windows + +package nsys + +import ( + "syscall" +) + +// Kill a process by pid +func Kill(pid int, signal syscall.Signal) error { + return syscall.Kill(pid, signal) +} + +// ProcessExists check process exists by pid +func ProcessExists(pid int) bool { + return nil == syscall.Kill(pid, 0) +} diff --git a/nsys/sysutil_test.go b/nsys/sysutil_test.go new file mode 100644 index 0000000..66d1ae3 --- /dev/null +++ b/nsys/sysutil_test.go @@ -0,0 +1,23 @@ +package nsys_test + +import ( + "git.noahlan.cn/noahlan/ntool/nsys" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "os" + "runtime" + "testing" +) + +func TestBasic_usage(t *testing.T) { + assert.NotEmpty(t, nsys.BinDir()) + assert.NotEmpty(t, nsys.BinFile()) +} + +func TestProcessExists(t *testing.T) { + if runtime.GOOS != "windows" { + pid := os.Getpid() + assert.True(t, nsys.ProcessExists(pid)) + } else { + t.Skip("on Windows") + } +} diff --git a/nsys/sysutil_unix.go b/nsys/sysutil_unix.go new file mode 100644 index 0000000..8fe1a17 --- /dev/null +++ b/nsys/sysutil_unix.go @@ -0,0 +1,54 @@ +//go:build !windows && !darwin + +package nsys + +import ( + "os/exec" + "strings" +) + +// IsWin system. linux windows darwin +func IsWin() bool { return false } + +// IsWindows system. linux windows darwin +func IsWindows() bool { return false } + +// IsMac system +func IsMac() bool { return false } + +// IsDarwin system +func IsDarwin() bool { return false } + +// IsLinux system +func IsLinux() bool { + return true +} + +// There are multiple possible providers to open a browser on linux +// One of them is xdg-open, another is x-www-browser, then there's www-browser, etc. +// Look for one that exists and run it +var openBins = []string{"xdg-open", "x-www-browser", "www-browser"} + +// OpenURL Open file or browser URL +// +// Mac: +// +// open 'https://github.com/inhere' +// +// Linux: +// +// xdg-open URL +// x-www-browser 'https://github.com/inhere' +// +// Windows: +// +// cmd /c start https://github.com/inhere +func OpenURL(URL string) error { + for _, bin := range openBins { + if _, err := exec.LookPath(bin); err == nil { + return exec.Command(bin, URL).Run() + } + } + + return &exec.Error{Name: strings.Join(openBins, ","), Err: exec.ErrNotFound} +} diff --git a/nsys/sysutil_windows.go b/nsys/sysutil_windows.go new file mode 100644 index 0000000..40ba7c2 --- /dev/null +++ b/nsys/sysutil_windows.go @@ -0,0 +1,56 @@ +//go:build windows + +package nsys + +import ( + "errors" + "syscall" + + "golang.org/x/sys/windows" +) + +// IsWin system. linux windows darwin +func IsWin() bool { return true } + +// IsWindows system. linux windows darwin +func IsWindows() bool { return true } + +// IsMac system +func IsMac() bool { return false } + +// IsDarwin system +func IsDarwin() bool { return false } + +// IsLinux system +func IsLinux() bool { return false } + +// Kill a process by pid +func Kill(pid int, signal syscall.Signal) error { + return errors.New("not support") +} + +// ProcessExists check process exists by pid +func ProcessExists(pid int) bool { + panic("TIP: please use sysutil/process.Exists()") +} + +// OpenURL Open file or browser URL +// +// - refers https://github.com/pkg/browser +// +// Mac: +// +// open 'https://github.com/inhere' +// +// Linux: +// +// xdg-open URL +// x-www-browser 'https://github.com/inhere' +// +// Windows: +// +// cmd /c start https://github.com/inhere +func OpenURL(url string) error { + // return exec.Command("cmd", "/C", "start", URL).Run() + return windows.ShellExecute(0, nil, windows.StringToUTF16Ptr(url), nil, nil, windows.SW_SHOWNORMAL) +} diff --git a/ntest/assert/assert.go b/ntest/assert/assert.go new file mode 100644 index 0000000..5182a67 --- /dev/null +++ b/ntest/assert/assert.go @@ -0,0 +1,8 @@ +package assert + +// TestingT is an interface wrapper around *testing.T +type TestingT interface { + Helper() + Name() string + Error(args ...any) +} diff --git a/ntest/assert/assertions.go b/ntest/assert/assertions.go new file mode 100644 index 0000000..cf54d45 --- /dev/null +++ b/ntest/assert/assertions.go @@ -0,0 +1,30 @@ +package assert + +// Assertions provides assertion methods around the TestingT interface. +type Assertions struct { + t TestingT + ok bool // last assert result + // prefix message for each assert TODO + Msg string +} + +// New makes a new Assertions object for the specified TestingT. +func New(t TestingT) *Assertions { + return &Assertions{t: t} +} + +// WithMsg set with prefix message. +func (as *Assertions) WithMsg(msg string) *Assertions { + as.Msg = msg + return as +} + +// IsOk for last check +func (as *Assertions) IsOk() bool { + return as.ok +} + +// IsFail for last check +func (as *Assertions) IsFail() bool { + return !as.ok +} diff --git a/ntest/assert/assertions_methods.go b/ntest/assert/assertions_methods.go new file mode 100644 index 0000000..1982169 --- /dev/null +++ b/ntest/assert/assertions_methods.go @@ -0,0 +1,243 @@ +package assert + +func (as *Assertions) Nil(give any, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = Nil(as.t, give, fmtAndArgs...) + return as +} + +// NotNil asserts that the given is a not nil value +func (as *Assertions) NotNil(val any, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = NotNil(as.t, val, fmtAndArgs...) + return as +} + +// True check, please see True() +func (as *Assertions) True(give bool, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = True(as.t, give, fmtAndArgs...) + return as +} + +// False check, please see False() +func (as *Assertions) False(give bool, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = False(as.t, give, fmtAndArgs...) + return as +} + +// Empty check, please see Empty() +func (as *Assertions) Empty(give any, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = Empty(as.t, give, fmtAndArgs...) + return as +} + +// NotEmpty check, please see NotEmpty() +func (as *Assertions) NotEmpty(give any, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = NotEmpty(as.t, give, fmtAndArgs...) + return as +} + +// Panics check, please see Panics() +func (as *Assertions) Panics(fn PanicRunFunc, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = Panics(as.t, fn, fmtAndArgs...) + return as +} + +// NotPanics check, please see NotPanics() +func (as *Assertions) NotPanics(fn PanicRunFunc, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = NotPanics(as.t, fn, fmtAndArgs...) + return as +} + +// PanicsMsg check, please see PanicsMsg() +func (as *Assertions) PanicsMsg(fn PanicRunFunc, wantVal any, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = PanicsMsg(as.t, fn, wantVal, fmtAndArgs...) + return as +} + +// PanicsErrMsg check, please see PanicsErrMsg() +func (as *Assertions) PanicsErrMsg(fn PanicRunFunc, errMsg string, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = PanicsErrMsg(as.t, fn, errMsg, fmtAndArgs...) + return as +} + +// Contains asserts that the given data(string,slice,map) should contain element +func (as *Assertions) Contains(src, elem any, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = Contains(as.t, src, elem, fmtAndArgs...) + return as +} + +// NotContains asserts that the given data(string,slice,map) should not contain element +func (as *Assertions) NotContains(src, elem any, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = NotContains(as.t, src, elem, fmtAndArgs...) + return as +} + +// ContainsKey asserts that the given map is contains key +func (as *Assertions) ContainsKey(mp, key any, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = ContainsKey(as.t, mp, key, fmtAndArgs...) + return as +} + +// StrContains asserts that the given strings is contains sub-string +func (as *Assertions) StrContains(s, sub string, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = StrContains(as.t, s, sub, fmtAndArgs...) + return as +} + +// NoErr asserts that the given is a nil error +func (as *Assertions) NoErr(err error, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = NoErr(as.t, err, fmtAndArgs...) + return as +} + +// Err asserts that the given is a not nil error +func (as *Assertions) Err(err error, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = Err(as.t, err, fmtAndArgs...) + return as +} + +// Error asserts that the given is a not nil error +func (as *Assertions) Error(err error, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = Err(as.t, err, fmtAndArgs...) + return as +} + +// ErrIs asserts that the given error is equals wantErr +func (as *Assertions) ErrIs(err, wantErr error, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = ErrIs(as.t, err, wantErr, fmtAndArgs...) + return as +} + +// ErrMsg asserts that the given is a not nil error and error message equals wantMsg +func (as *Assertions) ErrMsg(err error, errMsg string, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = ErrMsg(as.t, err, errMsg, fmtAndArgs...) + return as +} + +// ErrSubMsg asserts that the given is a not nil error and the error message contains subMsg +func (as *Assertions) ErrSubMsg(err error, subMsg string, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = ErrSubMsg(as.t, err, subMsg, fmtAndArgs...) + return as +} + +// Len assert given length is equals to wantLn +func (as *Assertions) Len(give any, wantLn int, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = Len(as.t, give, wantLn, fmtAndArgs...) + return as +} + +// LenGt assert given length is greater than to minLn +func (as *Assertions) LenGt(give any, minLn int, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = LenGt(as.t, give, minLn, fmtAndArgs...) + return as +} + +// Eq asserts that the want should equal to the given +func (as *Assertions) Eq(want, give any, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = Eq(as.t, want, give, fmtAndArgs...) + return as +} + +// Equal asserts that the want should equal to the given +// +// Alias of Eq() +func (as *Assertions) Equal(want, give any, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = Eq(as.t, want, give, fmtAndArgs...) + return as +} + +// Neq asserts that the want should not be equal to the given. +// alias of NotEq() +func (as *Assertions) Neq(want, give any, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = Neq(as.t, want, give, fmtAndArgs...) + return as +} + +// NotEq asserts that the want should not be equal to the given +func (as *Assertions) NotEq(want, give any, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = NotEq(as.t, want, give, fmtAndArgs...) + return as +} + +// NotEqual asserts that the want should not be equal to the given +// +// Alias of NotEq() +func (as *Assertions) NotEqual(want, give any, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = NotEq(as.t, want, give, fmtAndArgs...) + return as +} + +// Lt asserts that the give(intX) should not be less than max +func (as *Assertions) Lt(give, max any, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = Lt(as.t, give, max, fmtAndArgs...) + return as +} + +// Lte asserts that the give(intX) should not be less than or equal to max +func (as *Assertions) Lte(give, max any, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = Lte(as.t, give, max, fmtAndArgs...) + return as +} + +// Gt asserts that the give(intX) should not be greater than min +func (as *Assertions) Gt(give, min any, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = Gt(as.t, give, min, fmtAndArgs...) + return as +} + +// Gte asserts that the give(intX) should not be greater than or equal to min +func (as *Assertions) Gte(give, min any, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = Gte(as.t, give, min, fmtAndArgs...) + return as +} + +// IsType type equals assert +func (as *Assertions) IsType(wantType, give any, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = IsType(as.t, wantType, give, fmtAndArgs...) + return as +} + +// Fail reports a failure through +func (as *Assertions) Fail(failMsg string, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = Fail(as.t, failMsg, fmtAndArgs...) + return as +} + +// FailNow fails test +func (as *Assertions) FailNow(failMsg string, fmtAndArgs ...any) *Assertions { + as.t.Helper() + as.ok = FailNow(as.t, failMsg, fmtAndArgs...) + return as +} diff --git a/ntest/assert/assertions_test.go b/ntest/assert/assertions_test.go new file mode 100644 index 0000000..834c3d3 --- /dev/null +++ b/ntest/assert/assertions_test.go @@ -0,0 +1,43 @@ +package assert_test + +import ( + "errors" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestAssertions_Chain(t *testing.T) { + // err := "error message" + err := errors.New("error message") + + as := assert.New(t). + NotEmpty(err). + NotNil(err). + Err(err). + ErrMsg(err, "error message"). + Eq("error message", err.Error()). + Neq("message", err.Error()). + Equal("error message", err.Error()). + Contains(err.Error(), "message"). + StrContains(err.Error(), "message"). + NotContains(err.Error(), "success"). + Gt(4, 3). + Lt(2, 3) + + assert.True(t, as.IsOk()) + assert.False(t, as.IsFail()) + + iv := 23 + as = assert.New(t). + IsType(1, iv). + NotEq(22, iv). + NotEqual(22, iv). + Lte(iv, 23). + Gte(iv, 23). + Empty(0). + True(true). + False(false). + Nil(nil) + + assert.True(t, as.IsOk()) +} diff --git a/ntest/assert/asserts.go b/ntest/assert/asserts.go new file mode 100644 index 0000000..bdb8fc5 --- /dev/null +++ b/ntest/assert/asserts.go @@ -0,0 +1,730 @@ +package assert + +import ( + "errors" + "fmt" + "git.noahlan.cn/noahlan/ntool/narr" + "git.noahlan.cn/noahlan/ntool/nmap" + "git.noahlan.cn/noahlan/ntool/nmath" + "git.noahlan.cn/noahlan/ntool/nreflect" + "git.noahlan.cn/noahlan/ntool/nstd" + "github.com/gookit/color" + "reflect" + "runtime/debug" + "strings" +) + +// Nil asserts that the given is a nil value +func Nil(t TestingT, give any, fmtAndArgs ...any) bool { + if nstd.IsNil(give) { + return true + } + + t.Helper() + return fail(t, fmt.Sprintf("Expected nil, but got: %#v", give), fmtAndArgs) +} + +// NotNil asserts that the given is a not nil value +func NotNil(t TestingT, give any, fmtAndArgs ...any) bool { + if !nstd.IsNil(give) { + return true + } + + t.Helper() + return fail(t, "Should not nil value", fmtAndArgs) +} + +// True asserts that the given is a bool true +func True(t TestingT, give bool, fmtAndArgs ...any) bool { + if !give { + t.Helper() + return fail(t, "Result should be True", fmtAndArgs) + } + return true +} + +// False asserts that the given is a bool false +func False(t TestingT, give bool, fmtAndArgs ...any) bool { + if give { + t.Helper() + return fail(t, "Result should be False", fmtAndArgs) + } + return true +} + +// Empty asserts that the give should be empty +func Empty(t TestingT, give any, fmtAndArgs ...any) bool { + empty := isEmpty(give) + if !empty { + t.Helper() + return fail(t, fmt.Sprintf("Should be empty, but was:\n%#v", give), fmtAndArgs) + } + + return empty +} + +// NotEmpty asserts that the give should not be empty +func NotEmpty(t TestingT, give any, fmtAndArgs ...any) bool { + nEmpty := !isEmpty(give) + if !nEmpty { + t.Helper() + return fail(t, fmt.Sprintf("Should not be empty, but was:\n%#v", give), fmtAndArgs) + } + + return nEmpty +} + +// PanicRunFunc define +type PanicRunFunc func() + +// didPanic returns true if the function passed to it panics. Otherwise, it returns false. +func runPanicFunc(f PanicRunFunc) (didPanic bool, message any, stack string) { + didPanic = true + defer func() { + message = recover() + if didPanic { + stack = string(debug.Stack()) + } + }() + + // call the target function + f() + didPanic = false + + return +} + +// Panics asserts that the code inside the specified func panics. +func Panics(t TestingT, fn PanicRunFunc, fmtAndArgs ...any) bool { + if hasPanic, panicVal, _ := runPanicFunc(fn); !hasPanic { + t.Helper() + + return fail(t, fmt.Sprintf("func %#v should panic\n\tPanic value:\t%#v", fn, panicVal), fmtAndArgs) + } + + return true +} + +// NotPanics asserts that the code inside the specified func NOT panics. +func NotPanics(t TestingT, fn PanicRunFunc, fmtAndArgs ...any) bool { + if hasPanic, panicVal, stackMsg := runPanicFunc(fn); hasPanic { + t.Helper() + + return fail(t, fmt.Sprintf( + "func %#v should not panic\n\tPanic value:\t%#v\n\tPanic stack:\t%s", + fn, panicVal, stackMsg, + ), fmtAndArgs, + ) + } + + return true +} + +// PanicsMsg should panic and with a value +func PanicsMsg(t TestingT, fn PanicRunFunc, wantVal any, fmtAndArgs ...any) bool { + hasPanic, panicVal, stackMsg := runPanicFunc(fn) + if !hasPanic { + t.Helper() + return fail(t, fmt.Sprintf("func %#v should panic\n\tPanic value:\t%#v", fn, panicVal), fmtAndArgs) + } + + if panicVal != wantVal { + t.Helper() + return fail(t, fmt.Sprintf( + "func %#v should panic.\n\tWant value:\t%#v\n\tPanic value:\t%#v\n\tPanic stack:\t%s", + fn, wantVal, panicVal, stackMsg), + fmtAndArgs, + ) + } + + return true +} + +// PanicsErrMsg should panic and with error message +func PanicsErrMsg(t TestingT, fn PanicRunFunc, errMsg string, fmtAndArgs ...any) bool { + hasPanic, panicVal, stackMsg := runPanicFunc(fn) + if !hasPanic { + t.Helper() + return fail(t, fmt.Sprintf("func %#v should panic\n\tPanic value:\t%#v", fn, panicVal), fmtAndArgs) + } + + err, ok := panicVal.(error) + if !ok { + t.Helper() + return fail(t, fmt.Sprintf("func %#v should panic and is error type,\nbut type was: %T", fn, panicVal), fmtAndArgs) + } + + if err.Error() != errMsg { + t.Helper() + return fail(t, fmt.Sprintf( + "func %#v should panic.\n\tWant error:\t%#v\n\tPanic value:\t%#v\n\tPanic stack:\t%s", + fn, errMsg, panicVal, stackMsg), + fmtAndArgs, + ) + } + + return true +} + +// Contains asserts that the given data(string,slice,map) should contain element +// +// TIP: only support types: string, map, array, slice +// +// map - check key exists +// string - check sub-string exists +// array,slice - check sub-element exists +func Contains(t TestingT, src, elem any, fmtAndArgs ...any) bool { + valid, found := nstd.CheckContains(src, elem) + if valid && found { + return true + } + + t.Helper() + + // src invalid + if !valid { + return fail(t, fmt.Sprintf("%#v could not be applied builtin len()", src), fmtAndArgs) + } + + // not found + return fail(t, fmt.Sprintf("%#v\nShould contain: %#v", src, elem), fmtAndArgs) +} + +// NotContains asserts that the given data(string,slice,map) should not contain element +// +// TIP: only support types: string, map, array, slice +// +// map - check key exists +// string - check sub-string exists +// array,slice - check sub-element exists +func NotContains(t TestingT, src, elem any, fmtAndArgs ...any) bool { + valid, found := nstd.CheckContains(src, elem) + if valid && !found { + return true + } + + t.Helper() + + // src invalid + if !valid { + return fail(t, fmt.Sprintf("%#v could not be applied builtin len()", src), fmtAndArgs) + } + + // found + return fail(t, fmt.Sprintf("%#v\nShould not contain: %#v", src, elem), fmtAndArgs) +} + +// ContainsKey asserts that the given map is contains key +func ContainsKey(t TestingT, mp, key any, fmtAndArgs ...any) bool { + if !nmap.HasKey(mp, key) { + t.Helper() + return fail(t, + fmt.Sprintf( + "Map should contains the key: %#v\nMap data:\n%v", + key, + nmap.FormatIndent(mp, " "), + ), + fmtAndArgs, + ) + } + + return true +} + +// NotContainsKey asserts that the given map is not contains key +func NotContainsKey(t TestingT, mp, key any, fmtAndArgs ...any) bool { + if nmap.HasKey(mp, key) { + t.Helper() + return fail(t, + fmt.Sprintf( + "Map should not contains the key: %#v\nMap data:\n%v", + key, + nmap.FormatIndent(mp, " "), + ), + fmtAndArgs, + ) + } + + return true +} + +// ContainsKeys asserts that the map is contains all given keys +// +// Usage: +// +// ContainsKeys(t, map[string]any{...}, []string{"key1", "key2"}) +func ContainsKeys(t TestingT, mp any, keys any, fmtAndArgs ...any) bool { + anyKeys, err := narr.AnyToSlice(keys) + if err != nil { + t.Helper() + return fail(t, err.Error(), fmtAndArgs) + } + + ok, noKey := nmap.HasAllKeys(mp, anyKeys...) + if !ok { + t.Helper() + return fail(t, + fmt.Sprintf( + "Map should contains the key: %#v\nMap data:\n%v", + noKey, + nmap.FormatIndent(mp, " "), + ), + fmtAndArgs, + ) + } + + return true +} + +// NotContainsKeys asserts that the map is not contains all given keys +// +// Usage: +// +// NotContainsKeys(t, map[string]any{...}, []string{"key1", "key2"}) +func NotContainsKeys(t TestingT, mp any, keys any, fmtAndArgs ...any) bool { + anyKeys, err := narr.AnyToSlice(keys) + if err != nil { + t.Helper() + return fail(t, err.Error(), fmtAndArgs) + } + + ok, hasKey := nmap.HasOneKey(mp, anyKeys...) + if ok { + t.Helper() + return fail(t, + fmt.Sprintf( + "Map should not contains the key: %#v\nMap data:\n%v", + hasKey, + nmap.FormatIndent(mp, " "), + ), + fmtAndArgs, + ) + } + + return true +} + +// StrContains asserts that the given strings is contains sub-string +func StrContains(t TestingT, s, sub string, fmtAndArgs ...any) bool { + if strings.Contains(s, sub) { + return true + } + + t.Helper() + return fail(t, + fmt.Sprintf("String check fail:\nGiven string: %#v\nNot contains: %#v", s, sub), + fmtAndArgs, + ) +} + +// StrCount asserts that the given strings is contains sub-string and count +func StrCount(t TestingT, s, sub string, count int, fmtAndArgs ...any) bool { + if strings.Count(s, sub) == count { + return true + } + + t.Helper() + return fail(t, + fmt.Sprintf("String check fail:\nGiven string: %s\nNot contains %q count: %d", s, sub, count), + fmtAndArgs, + ) +} + +// +// -------------------- error -------------------- +// + +// NoError asserts that the given is a nil error. alias of NoError() +func NoError(t TestingT, err error, fmtAndArgs ...any) bool { + t.Helper() + return NoErr(t, err, fmtAndArgs...) +} + +// NoErr asserts that the given is a nil error +func NoErr(t TestingT, err error, fmtAndArgs ...any) bool { + if err != nil { + t.Helper() + return fail(t, fmt.Sprintf("Received unexpected error:\n%+v", err), fmtAndArgs) + } + return true +} + +// Error asserts that the given is a not nil error. alias of Error() +func Error(t TestingT, err error, fmtAndArgs ...any) bool { + t.Helper() + return Err(t, err, fmtAndArgs...) +} + +// Err asserts that the given is a not nil error +func Err(t TestingT, err error, fmtAndArgs ...any) bool { + if err == nil { + t.Helper() + return fail(t, "An error is expected but got nil.", fmtAndArgs) + } + return true +} + +// ErrIs asserts that the given error is equals wantErr +func ErrIs(t TestingT, err, wantErr error, fmtAndArgs ...any) bool { + if err == nil { + t.Helper() + return fail(t, "An error is expected but got nil.", fmtAndArgs) + } + + if !errors.Is(err, wantErr) { + t.Helper() + return fail(t, fmt.Sprintf("Expect given err is equals %#v.", wantErr), fmtAndArgs) + } + + return true +} + +// ErrMsg asserts that the given is a not nil error and error message equals wantMsg +func ErrMsg(t TestingT, err error, wantMsg string, fmtAndArgs ...any) bool { + if err == nil { + t.Helper() + return fail(t, "An error is expected but got nil.", fmtAndArgs) + } + + errMsg := err.Error() + if errMsg != wantMsg { + t.Helper() + return fail(t, fmt.Sprintf("Error message not equal:\n"+ + "expect: %q\n"+ + "actual: %q", wantMsg, errMsg), fmtAndArgs) + } + + return true +} + +// ErrSubMsg asserts that the given is a not nil error and the error message contains subMsg +func ErrSubMsg(t TestingT, err error, subMsg string, fmtAndArgs ...any) bool { + if err == nil { + t.Helper() + return fail(t, "An error is expected but got nil.", fmtAndArgs) + } + + errMsg := err.Error() + if !strings.Contains(errMsg, subMsg) { + t.Helper() + return fail(t, fmt.Sprintf("Error message check fail:\n"+ + "error message : %q\n"+ + "should contains: %q", errMsg, subMsg), fmtAndArgs) + } + + return true +} + +// +// -------------------- Len -------------------- +// + +// Len assert given length is equals to wantLn +func Len(t TestingT, give any, wantLn int, fmtAndArgs ...any) bool { + gln := nreflect.Len(reflect.ValueOf(give)) + if gln < 0 { + t.Helper() + return fail(t, fmt.Sprintf("\"%s\" could not be calc length", give), fmtAndArgs) + } + + if gln != wantLn { + t.Helper() + return fail(t, fmt.Sprintf("\"%s\" should have %d item(s), but has %d", give, wantLn, gln), fmtAndArgs) + } + return false +} + +// LenGt assert given length is greater than to minLn +func LenGt(t TestingT, give any, minLn int, fmtAndArgs ...any) bool { + gln := nreflect.Len(reflect.ValueOf(give)) + if gln < 0 { + t.Helper() + return fail(t, fmt.Sprintf("\"%s\" could not be calc length", give), fmtAndArgs) + } + + if gln < minLn { + t.Helper() + return fail(t, fmt.Sprintf("\"%s\" should less have %d item(s), but has %d", give, minLn, gln), fmtAndArgs) + } + return false +} + +// +// -------------------- compare -------------------- +// + +// Equal asserts that the want should equal to the given. +// +// alias of Eq() +func Equal(t TestingT, want, give any, fmtAndArgs ...any) bool { + t.Helper() + return Eq(t, want, give, fmtAndArgs...) +} + +// Eq asserts that the want should equal to the given +func Eq(t TestingT, want, give any, fmtAndArgs ...any) bool { + t.Helper() + + if err := checkEqualArgs(want, give); err != nil { + return fail(t, + fmt.Sprintf("Cannot compare: %#v == %#v (%s)", want, give, err), + fmtAndArgs, + ) + } + + if !nreflect.IsEqual(want, give) { + // TODO diff := diff(want, give) + want, give = formatUnequalValues(want, give) + return fail(t, fmt.Sprintf("Not equal: \n"+ + "expect: %s\n"+ + "actual: %s", want, give), fmtAndArgs) + } + + return true +} + +// Neq asserts that the want should not be equal to the given. +// +// alias of NotEq() +func Neq(t TestingT, want, give any, fmtAndArgs ...any) bool { + t.Helper() + return NotEq(t, want, give, fmtAndArgs...) +} + +// NotEqual asserts that the want should not be equal to the given. +// +// alias of NotEq() +func NotEqual(t TestingT, want, give any, fmtAndArgs ...any) bool { + t.Helper() + return NotEq(t, want, give, fmtAndArgs...) +} + +// NotEq asserts that the want should not be equal to the given +func NotEq(t TestingT, want, give any, fmtAndArgs ...any) bool { + t.Helper() + + if err := checkEqualArgs(want, give); err != nil { + return fail(t, + fmt.Sprintf("Cannot compare: %#v == %#v (%s)", want, give, err), + fmtAndArgs, + ) + } + + if nreflect.IsEqual(want, give) { + return fail(t, fmt.Sprintf("Given should not be: %#v\n", give), fmtAndArgs) + } + return true +} + +// EqualValues asserts that two objects are equal or convertable to the same types +// and equal. +// +// assert.EqualValues(t, uint32(123), int32(123)) +func EqualValues(t TestingT, expected, actual any, fmtAndArgs ...any) bool { + t.Helper() + + if !nreflect.IsEqualValues(expected, actual) { + //diff := diff(expected, actual) + expected, actual = formatUnequalValues(expected, actual) + return fail(t, fmt.Sprintf("Not equal: \n"+ + "expected: %s\n"+ + "actual : %s", expected, actual), fmtAndArgs) + } + + return true + +} + +// Lt asserts that the give(intX,uintX,floatX) should not be less than max +func Lt(t TestingT, give, max any, fmtAndArgs ...any) bool { + if nmath.Compare(give, max, "lt") { + return true + } + + t.Helper() + return fail(t, fmt.Sprintf("Given %v should later than %v", give, max), fmtAndArgs) +} + +// Lte asserts that the give(intX,uintX,floatX) should not be less than or equals to max +func Lte(t TestingT, give, max any, fmtAndArgs ...any) bool { + if nmath.Compare(give, max, "lte") { + return true + } + + t.Helper() + return fail(t, fmt.Sprintf("Given %v should later than %v", give, max), fmtAndArgs) +} + +// Gt asserts that the give(intX,uintX,floatX) should not be greater than min +func Gt(t TestingT, give, min any, fmtAndArgs ...any) bool { + if nmath.Compare(give, min, "gt") { + return true + } + + t.Helper() + return fail(t, fmt.Sprintf("Given %v should gater than %v", give, min), fmtAndArgs) +} + +// Gte asserts that the give(intX,uintX,floatX) should not be greater than or equals to min +func Gte(t TestingT, give, min any, fmtAndArgs ...any) bool { + if nmath.Compare(give, min, "gte") { + return true + } + + t.Helper() + return fail(t, fmt.Sprintf("Given %v should gater than or equal %v", give, min), fmtAndArgs) +} + +// IsType assert data type equals +// +// Usage: +// +// assert.IsType(t, 0, val) // assert type is int +func IsType(t TestingT, wantType, give any, fmtAndArgs ...any) bool { + if nreflect.IsEqual(reflect.TypeOf(wantType), reflect.TypeOf(give)) { + return true + } + + t.Helper() + return fail(t, + fmt.Sprintf("Expected to be of type %v, but was %v", reflect.TypeOf(wantType), reflect.TypeOf(give)), + fmtAndArgs, + ) +} + +// IsKind assert data reflect.Kind equals. +// If `give` is ptr or interface, will get real kind. +// +// Usage: +// +// assert.IsKind(t, reflect.Int, val) // assert type is int kind. +func IsKind(t TestingT, wantKind reflect.Kind, give any, fmtAndArgs ...any) bool { + giveKind := nreflect.Elem(reflect.ValueOf(give)).Kind() + if wantKind == giveKind { + return true + } + + t.Helper() + return fail(t, + fmt.Sprintf("Expected to be of kind %v, but was %v", wantKind, giveKind), + fmtAndArgs, + ) +} + +// Same asserts that two pointers reference the same object. +// +// assert.Same(t, ptr1, ptr2) +// +// Both arguments must be pointer variables. Pointer variable sameness is +// determined based on the equality of both type and value. +func Same(t TestingT, wanted, actual any, fmtAndArgs ...any) bool { + if samePointers(wanted, actual) { + return true + } + + return fail(t, fmt.Sprintf("Not same: \n"+ + "wanted: %p %#v\n"+ + "actual: %p %#v", wanted, wanted, actual, actual), fmtAndArgs) +} + +// NotSame asserts that two pointers do not reference the same object. +// +// assert.NotSame(t, ptr1, ptr2) +// +// Both arguments must be pointer variables. Pointer variable sameness is +// determined based on the equality of both type and value. +func NotSame(t TestingT, want, actual any, fmtAndArgs ...any) bool { + if !samePointers(want, actual) { + return true + } + + t.Helper() + return fail(t, fmt.Sprintf("Expect and actual point to the same object: %p %#v", want, want), fmtAndArgs) +} + +// samePointers compares two generic interface objects and returns whether +// they point to the same object +func samePointers(first, second any) bool { + firstPtr, secondPtr := reflect.ValueOf(first), reflect.ValueOf(second) + if firstPtr.Kind() != reflect.Ptr || secondPtr.Kind() != reflect.Ptr { + return false + } + + firstType, secondType := reflect.TypeOf(first), reflect.TypeOf(second) + if firstType != secondType { + return false + } + + // compare pointer addresses + return first == second +} + +// +// -------------------- fail -------------------- +// + +// Fail reports a failure through +func Fail(t TestingT, failMsg string, fmtAndArgs ...any) bool { + t.Helper() + return fail(t, failMsg, fmtAndArgs) +} + +type failNower interface { + FailNow() +} + +// FailNow fails test +func FailNow(t TestingT, failMsg string, fmtAndArgs ...any) bool { + t.Helper() + fail(t, failMsg, fmtAndArgs) + + if fnr, ok := t.(failNower); ok { + fnr.FailNow() + } + return false +} + +// +// -------------------- render error -------------------- +// + +var ( + // ShowFullPath on show error trace + ShowFullPath = true + // EnableColor on show error trace + EnableColor = true +) + +// DisableColor render +func DisableColor() { + EnableColor = false +} + +// HideFullPath render +func HideFullPath() { + ShowFullPath = false +} + +// fail reports a failure through +func fail(t TestingT, failMsg string, fmtAndArgs []any) bool { + t.Helper() + + tName := t.Name() + if EnableColor { + tName = color.Red.Sprint(tName) + } + + labeledTexts := []labeledText{ + {"Test Name", tName}, + {"Error At", strings.Join(callerInfos(), "\n")}, + {"Error Msg", failMsg}, + } + + // user custom message + if userMsg := formatTplAndArgs(fmtAndArgs...); len(userMsg) > 0 { + labeledTexts = append(labeledTexts, labeledText{"User Msg", userMsg}) + } + + t.Error("\n" + formatLabeledTexts(labeledTexts)) + return false +} diff --git a/ntest/assert/asserts_test.go b/ntest/assert/asserts_test.go new file mode 100644 index 0000000..aa4ac36 --- /dev/null +++ b/ntest/assert/asserts_test.go @@ -0,0 +1,34 @@ +package assert_test + +import ( + "errors" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestCommon(t *testing.T) { + assert.Nil(t, nil) + assert.False(t, false) + assert.True(t, true) +} + +func TestErr(t *testing.T) { + err := errors.New("this is a error") + // assert2.EqualError(t, err, "user custom message") + assert.Err(t, err, "user custom message") + assert.ErrMsg(t, err, "this is a error") +} + +func TestContains(t *testing.T) { + str := "abc+123" + assert.StrContains(t, str, "123") +} + +func TestEq(t *testing.T) { + str := "abc" + + assert.Eq(t, "abc", str) + assert.Panics(t, func() { + panic("hh") + }) +} diff --git a/ntest/assert/util.go b/ntest/assert/util.go new file mode 100644 index 0000000..9be5433 --- /dev/null +++ b/ntest/assert/util.go @@ -0,0 +1,190 @@ +package assert + +import ( + "bufio" + "errors" + "fmt" + "git.noahlan.cn/noahlan/ntool/nfs" + "git.noahlan.cn/noahlan/ntool/nmath" + "git.noahlan.cn/noahlan/ntool/nreflect" + "git.noahlan.cn/noahlan/ntool/nstd/io" + "git.noahlan.cn/noahlan/ntool/nstr" + "github.com/gookit/color" + "reflect" + "runtime" + "strings" + "time" +) + +// isEmpty value check +func isEmpty(v any) bool { + if v == nil { + return true + } + return nreflect.IsEmpty(reflect.ValueOf(v)) +} + +func checkEqualArgs(expected, actual any) error { + if expected == nil && actual == nil { + return nil + } + + if nreflect.IsFunc(expected) || nreflect.IsFunc(actual) { + return errors.New("cannot take func type as argument") + } + return nil +} + +// formatUnequalValues takes two values of arbitrary types and returns string +// representations appropriate to be presented to the user. +// +// If the values are not of like type, the returned strings will be prefixed +// with the type name, and the value will be enclosed in parentheses similar +// to a type conversion in the Go grammar. +func formatUnequalValues(expected, actual any) (e string, a string) { + if reflect.TypeOf(expected) != reflect.TypeOf(actual) { + return truncatingFormat(expected), truncatingFormat(actual) + // return fmt.Sprintf("%T(%s)", expected, truncatingFormat(expected)), + // fmt.Sprintf("%T(%s)", actual, truncatingFormat(actual)) + } + + switch expected.(type) { + case time.Duration: + return fmt.Sprintf("%v", expected), fmt.Sprintf("%v", actual) + } + + return truncatingFormat(expected), truncatingFormat(actual) +} + +// truncatingFormat formats the data and truncates it if it's too long. +// +// This helps keep formatted error messages lines from exceeding the +// bufio.MaxScanTokenSize max line length that the go testing framework imposes. +func truncatingFormat(data any) string { + if data == nil { + return "" + } + + var value string + switch data.(type) { + case string: + value = fmt.Sprintf("string(%q)", data) + default: + value = fmt.Sprintf("%T(%v)", data, data) + } + + // Give us some space the type info too if needed. + max := bufio.MaxScanTokenSize - 100 + if len(value) > max { + value = value[0:max] + "<... truncated>" + } + return value +} + +func formatTplAndArgs(fmtAndArgs ...any) string { + if len(fmtAndArgs) == 0 || fmtAndArgs == nil { + return "" + } + + ln := len(fmtAndArgs) + first := fmtAndArgs[0] + + if ln == 1 { + if msgAsStr, ok := first.(string); ok { + return msgAsStr + } + return fmt.Sprintf("%+v", first) + } + + // is template string. + if tplStr, ok := first.(string); ok { + return fmt.Sprintf(tplStr, fmtAndArgs[1:]...) + } + return fmt.Sprint(fmtAndArgs...) +} + +func callerInfos() []string { + num := 3 + skip := 2 + ss := make([]string, 0, num) + + for i := skip; i < skip+num; i++ { + pc, file, line, ok := runtime.Caller(i) + if !ok { + // The breaks below failed to terminate the loop, and we ran off the + // end of the call stack. + break + } + + fc := runtime.FuncForPC(pc) + if fc == nil { + continue + } + + // This is a huge edge case, but it will panic if this is the case + if file == "" { + continue + } + + fcName := fc.Name() + if fcName == "testing.tRunner" || strings.Contains(fcName, "goutil/testutil/assert") { + continue + } + + // eg: runtime.goexit + if strings.HasPrefix(fcName, "runtime.") { + continue + } + + filePath := file + if !ShowFullPath { + filePath = nfs.Name(filePath) + } + + ss = append(ss, fmt.Sprintf("%s:%d", filePath, line)) + } + + return ss +} + +// refers from stretchr/testify/assert +type labeledText struct { + label string + message string +} + +func formatLabeledTexts(lts []labeledText) string { + labelWidth := 0 + elemSize := len(lts) + for _, lt := range lts { + labelWidth = nmath.MaxInt(len(lt.label), labelWidth) + } + + var sb strings.Builder + for i, lt := range lts { + label := lt.label + if EnableColor { + label = color.Green.Sprint(label) + } + + sb.WriteString(" " + label + nstr.Repeat(" ", labelWidth-len(lt.label)) + ": ") + formatMessage(lt.message, labelWidth, &sb) + if i+1 != elemSize { + sb.WriteByte('\n') + } + } + return sb.String() +} + +func formatMessage(message string, labelWidth int, buf io.StringWriteStringer) string { + for i, scanner := 0, bufio.NewScanner(strings.NewReader(message)); scanner.Scan(); i++ { + // skip add prefix for first line. + if i != 0 { + // +3: is len of ": " + _, _ = buf.WriteString("\n " + strings.Repeat(" ", labelWidth+3)) + } + _, _ = buf.WriteString(scanner.Text()) + } + + return buf.String() +} diff --git a/ntest/buffer.go b/ntest/buffer.go new file mode 100644 index 0000000..72b40e4 --- /dev/null +++ b/ntest/buffer.go @@ -0,0 +1,43 @@ +package ntest + +import ( + "bytes" + "fmt" +) + +// Buffer wrap and extends the bytes.Buffer +type Buffer struct { + bytes.Buffer +} + +// NewBuffer instance +func NewBuffer() *Buffer { + return &Buffer{} +} + +// WriteString rewrite +func (b *Buffer) WriteString(ss ...string) { + for _, s := range ss { + _, _ = b.Buffer.WriteString(s) + } +} + +// WriteAny method +func (b *Buffer) WriteAny(vs ...any) { + for _, v := range vs { + _, _ = b.Buffer.WriteString(fmt.Sprint(v)) + } +} + +// Writeln method +func (b *Buffer) Writeln(s string) { + _, _ = b.Buffer.WriteString(s) + _ = b.Buffer.WriteByte('\n') +} + +// ResetAndGet buffer string. +func (b *Buffer) ResetAndGet() string { + s := b.String() + b.Reset() + return s +} diff --git a/ntest/buffer_test.go b/ntest/buffer_test.go new file mode 100644 index 0000000..75ce15a --- /dev/null +++ b/ntest/buffer_test.go @@ -0,0 +1,20 @@ +package ntest_test + +import ( + "git.noahlan.cn/noahlan/ntool/ntest" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestNewBuffer(t *testing.T) { + buf := ntest.NewBuffer() + + buf.WriteString("ab", "-", "cd") + assert.Eq(t, "ab-cd", buf.ResetAndGet()) + + buf.WriteAny(23, "abc") + assert.Eq(t, "23abc", buf.ResetAndGet()) + + buf.Writeln("abc") + assert.Eq(t, "abc\n", buf.ResetAndGet()) +} diff --git a/ntest/mock/env.go b/ntest/mock/env.go new file mode 100644 index 0000000..985b767 --- /dev/null +++ b/ntest/mock/env.go @@ -0,0 +1,118 @@ +package mock + +import ( + "os" + "strings" +) + +// Env mocking + +// MockEnvValue will store old env value, set new val. will restore old value on end. +func MockEnvValue(key, val string, fn func(nv string)) { + old := os.Getenv(key) + err := os.Setenv(key, val) + if err != nil { + panic(err) + } + + fn(os.Getenv(key)) + + // if old is empty, unset key. + if old == "" { + err = os.Unsetenv(key) + } else { + err = os.Setenv(key, old) + } + if err != nil { + panic(err) + } +} + +// MockEnvValues will store old env value, set new val. will restore old value on end. +func MockEnvValues(kvMap map[string]string, fn func()) { + backups := make(map[string]string, len(kvMap)) + + for key, val := range kvMap { + backups[key] = os.Getenv(key) + _ = os.Setenv(key, val) + } + + fn() + + for key := range kvMap { + if old := backups[key]; old == "" { + _ = os.Unsetenv(key) + } else { + _ = os.Setenv(key, old) + } + } +} + +// MockOsEnvByText by env text string. +// will clear all old ENV data, use given data map. +// will recover old ENV after fn run. +func MockOsEnvByText(envText string, fn func()) { + ss := strings.Split(envText, "\n") + mp := make(map[string]string, len(ss)) + + for _, line := range ss { + if line = strings.TrimSpace(line); line == "" { + continue + } + nodes := strings.SplitN(line, "=", 2) + envKey := strings.TrimSpace(nodes[0]) + + if len(nodes) < 2 { + mp[envKey] = "" + } else { + mp[envKey] = strings.TrimSpace(nodes[1]) + } + } + + MockCleanOsEnv(mp, fn) +} + +// MockOsEnv by env map data. alias of MockCleanOsEnv +func MockOsEnv(mp map[string]string, fn func()) { + MockCleanOsEnv(mp, fn) +} + +// backup os ENV +var envBak = os.Environ() + +// ClearOSEnv info. +// +// Usage: +// +// testutil.ClearOSEnv() +// defer testutil.RevertOSEnv() +// // do something ... +func ClearOSEnv() { os.Clearenv() } + +// RevertOSEnv info +func RevertOSEnv() { + os.Clearenv() + for _, str := range envBak { + nodes := strings.SplitN(str, "=", 2) + _ = os.Setenv(nodes[0], nodes[1]) + } +} + +// MockCleanOsEnv by env map data. +// +// will clear all old ENV data, use given data map. +// will recover old ENV after fn run. +func MockCleanOsEnv(mp map[string]string, fn func()) { + os.Clearenv() + for key, val := range mp { + _ = os.Setenv(key, val) + } + + fn() + + os.Clearenv() + for _, str := range envBak { + nodes := strings.SplitN(str, "=", 2) + _ = os.Setenv(nodes[0], nodes[1]) + } +} diff --git a/ntest/mock/env_test.go b/ntest/mock/env_test.go new file mode 100644 index 0000000..df29b8c --- /dev/null +++ b/ntest/mock/env_test.go @@ -0,0 +1,55 @@ +package mock_test + +import ( + "git.noahlan.cn/noahlan/ntool/nenv" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "git.noahlan.cn/noahlan/ntool/ntest/mock" + "os" + "testing" +) + +func TestMockEnvValue(t *testing.T) { + is := assert.New(t) + is.Eq("", os.Getenv("APP_COMMAND")) + + mock.MockEnvValue("APP_COMMAND", "new val", func(nv string) { + is.Eq("new val", nv) + }) + + shellVal := "custom-value" + mock.MockEnvValue("SHELL", shellVal, func(newVal string) { + is.Eq(shellVal, newVal) + }) + + is.Eq("", os.Getenv("APP_COMMAND")) + is.Panics(func() { + mock.MockEnvValue("invalid=", "value", nil) + }) +} + +func TestMockEnvValues(t *testing.T) { + is := assert.New(t) + is.Eq("", os.Getenv("APP_COMMAND")) + + mock.MockEnvValues(map[string]string{ + "APP_COMMAND": "new val", + }, func() { + is.Eq("new val", os.Getenv("APP_COMMAND")) + }) + + is.Eq("", os.Getenv("APP_COMMAND")) +} + +func TestMockOsEnvByText(t *testing.T) { + envStr := ` +APP_COMMAND = login +APP_ENV = dev +APP_DEBUG = true +` + + mock.MockOsEnvByText(envStr, func() { + assert.Len(t, nenv.Environ(), 3) + assert.Eq(t, "true", os.Getenv("APP_DEBUG")) + assert.Eq(t, "login", os.Getenv("APP_COMMAND")) + }) +} diff --git a/ntest/mock/fs.go b/ntest/mock/fs.go new file mode 100644 index 0000000..b62287b --- /dev/null +++ b/ntest/mock/fs.go @@ -0,0 +1,42 @@ +package mock + +import ( + "git.noahlan.cn/noahlan/ntool/ngo" + "io/fs" + "path" +) + +// DirEnt create a fs.DirEntry +type DirEnt struct { + Nam string + Dir bool + Typ fs.FileMode + Fi fs.FileInfo + Err error +} + +// NewDirEnt create a fs.DirEntry +func NewDirEnt(fpath string, isDir ...bool) *DirEnt { + isd := ngo.FirstOr(isDir, false) + return &DirEnt{Nam: path.Base(fpath), Dir: isd, Typ: fs.ModePerm} +} + +// Name get +func (d *DirEnt) Name() string { + return d.Nam +} + +// IsDir get +func (d *DirEnt) IsDir() bool { + return d.Dir +} + +// Type get +func (d *DirEnt) Type() fs.FileMode { + return d.Typ +} + +// Info get +func (d *DirEnt) Info() (fs.FileInfo, error) { + return d.Fi, d.Err +} diff --git a/ntest/writer.go b/ntest/writer.go new file mode 100644 index 0000000..0bdf246 --- /dev/null +++ b/ntest/writer.go @@ -0,0 +1,63 @@ +package ntest + +import "fmt" + +// TestWriter struct, useful for testing +type TestWriter struct { + Buffer + // ErrOnWrite return error on write, useful for testing + ErrOnWrite bool + // ErrOnFlush return error on flush, useful for testing + ErrOnFlush bool + // ErrOnClose return error on close, useful for testing + ErrOnClose bool +} + +// NewTestWriter instance +func NewTestWriter() *TestWriter { + return &TestWriter{} +} + +// SetErrOnWrite method +func (w *TestWriter) SetErrOnWrite() *TestWriter { + w.ErrOnWrite = true + return w +} + +// SetErrOnFlush method +func (w *TestWriter) SetErrOnFlush() *TestWriter { + w.ErrOnFlush = true + return w +} + +// SetErrOnClose method +func (w *TestWriter) SetErrOnClose() *TestWriter { + w.ErrOnClose = true + return w +} + +// Flush implements +func (w *TestWriter) Flush() error { + if w.ErrOnFlush { + return fmt.Errorf("flush error") + } + + w.Reset() + return nil +} + +// Close implements +func (w *TestWriter) Close() error { + if w.ErrOnClose { + return fmt.Errorf("close error") + } + return nil +} + +// Write implements +func (w *TestWriter) Write(p []byte) (n int, err error) { + if w.ErrOnWrite { + return 0, fmt.Errorf("write error") + } + return w.Buffer.Write(p) +} diff --git a/ntest/writer_test.go b/ntest/writer_test.go new file mode 100644 index 0000000..37b4a3d --- /dev/null +++ b/ntest/writer_test.go @@ -0,0 +1,28 @@ +package ntest_test + +import ( + "git.noahlan.cn/noahlan/ntool/ntest" + "git.noahlan.cn/noahlan/ntool/ntest/assert" + "testing" +) + +func TestNewTestWriter(t *testing.T) { + tw := ntest.NewTestWriter() + _, err := tw.Write([]byte("hello")) + assert.NoErr(t, err) + assert.Eq(t, "hello", tw.String()) + assert.NoErr(t, tw.Flush()) + assert.Eq(t, "", tw.String()) + assert.NoErr(t, tw.Close()) + + tw.SetErrOnWrite() + _, err = tw.Write([]byte("hello")) + assert.Err(t, err) + assert.Eq(t, "", tw.String()) + + tw.SetErrOnFlush() + assert.Err(t, tw.Flush()) + + tw.SetErrOnClose() + assert.Err(t, tw.Close()) +} diff --git a/ntime/check.go b/ntime/check.go new file mode 100644 index 0000000..64a8469 --- /dev/null +++ b/ntime/check.go @@ -0,0 +1 @@ +package ntime diff --git a/ntime/config.go b/ntime/config.go new file mode 100644 index 0000000..649b05c --- /dev/null +++ b/ntime/config.go @@ -0,0 +1,11 @@ +package ntime + +// TimeFormatConf 时间配置 +type TimeFormatConf struct { + // DateTime 日期 时间 格式 + DateTime string `json:",optional,default=2006-01-02 15:04:05"` + // Date 日期格式 + Date string `json:",optional,default=2006-01-02"` + // Time 时间格式 + Time string `json:",optional,default=15:04:05"` +} diff --git a/ntime/format.go b/ntime/format.go new file mode 100644 index 0000000..f579e3a --- /dev/null +++ b/ntime/format.go @@ -0,0 +1,240 @@ +package ntime + +import ( + "errors" + "fmt" + "git.noahlan.cn/noahlan/ntool/ngo" + "regexp" + "strconv" + "strings" + "time" +) + +var ( + // ErrDateLayout error + ErrDateLayout = errors.New("invalid date layout string") + // ErrInvalidParam error + ErrInvalidParam = errors.New("invalid input for parse time") +) + +// Format convert time to string use default layout +func Format(t time.Time) string { + return t.Format(DefaultLayout) +} + +// FormatBy convert time to string use given layout or template(java), +func FormatBy(t time.Time, layout ...string) string { + return t.Format(ToLayout(ngo.FirstOr(layout, DefaultLayout))) +} + +// FormatUnix time seconds use default layout +func FormatUnix(sec int64, layout ...string) string { + return time.Unix(sec, 0).Format(ToLayout(ngo.FirstOr(layout, DefaultLayout))) +} + +var timeFormats = [][]int{ + {0}, + {1}, + {2, 1}, + {60}, + {120, 60}, + {3600}, + {7200, 3600}, + {86400}, + {172800, 86400}, +} + +var timeMessages = []string{ + "< 1 sec", "1 sec", "secs", "1 min", "mins", "1 hr", "hrs", "1 day", "days", +} + +// HowLongAgo format given timestamp to string. +func HowLongAgo(sec int64) string { + intVal := int(sec) + length := len(timeFormats) + + for i, item := range timeFormats { + if intVal >= item[0] { + ni := i + 1 + match := false + + if ni < length { // next exists + next := timeFormats[ni] + if intVal < next[0] { // current <= intVal < next + match = true + } + } else if ni == length { // current is last + match = true + } + + if match { // match success + if len(item) == 1 { + return timeMessages[i] + } + + // len is 2 + return fmt.Sprintf("%d %s", intVal/item[1], timeMessages[i]) + } + } + } + + return "unknown" // He should never happen +} + +// auto match use some commonly layouts. +// key is layout length. +var layoutMap = map[int][]string{ + 6: {"200601", "060102", time.Kitchen}, + 8: {"20060102"}, + 10: {time.DateOnly}, + 13: {"2006-01-02 15"}, + 15: {time.Stamp}, + 16: {"2006-01-02 15:04"}, + 19: {time.DateTime, time.RFC822, time.StampMilli}, + 20: {"2006-01-02 15:04:05Z"}, + 21: {time.RFC822Z}, + 22: {time.StampMicro}, + 23: {"2006-01-02 15:04:05.000", "2006-01-02 15:04:05.999"}, + 24: {time.ANSIC}, + 25: {time.RFC3339, time.StampNano}, + // time.Layout}, // must go >= 1.19 + 26: {"2006-01-02 15:04:05.000000"}, + 28: {time.UnixDate}, + 29: {time.RFC1123, "2006-01-02 15:04:05.000000000"}, + 30: {time.RFC850}, + 31: {time.RFC1123Z}, + 35: {time.RFC3339Nano}, +} + +// MustParseTime must convert date time string to time.Time +// it will return ZeroTime when parsing error occurred. see MustParseTimeD +func MustParseTime(s string, layouts ...string) time.Time { + return MustParseTimeD(s, ZeroTime, layouts...) +} + +// MustParseTimeD must convert date time string to time.Time +// it will return defaultTime when parsing error occurred. see ParseTime +func MustParseTimeD(s string, defaultTime time.Time, layouts ...string) time.Time { + t, err := ParseTime(s, layouts...) + if err != nil { + return defaultTime + } + return t +} + +// ParseTime convert date time string to time.Time +// it will use some commonly layouts when layouts is empty or nil +func ParseTime(s string, layouts ...string) (t time.Time, err error) { + // custom layout + if len(layouts) > 0 { + if len(layouts[0]) > 0 { + return time.Parse(ToLayout(layouts[0]), s) + } + err = ErrDateLayout + return + } + + // auto match use some commonly layouts. + strLn := len(s) + maybeLayouts, ok := layoutMap[strLn] + if !ok { + err = ErrInvalidParam + return + } + + var hasAlphaT bool + if pos := strings.IndexByte(s, 'T'); pos > 0 && pos < 12 { + hasAlphaT = true + } + + hasSlashR := strings.IndexByte(s, '/') > 0 + for _, layout := range maybeLayouts { + // date string has "T". eg: "2006-01-02T15:04:05" + if hasAlphaT { + layout = strings.Replace(layout, " ", "T", 1) + } + + // date string has "/". eg: "2006/01/02 15:04:05" + if hasSlashR { + layout = strings.Replace(layout, "-", "/", -1) + } + + t, err = time.Parse(layout, s) + if err == nil { + return + } + } + + // t, err = time.ParseInLocation(layout, s, time.Local) + return +} + +var ( + // TIP: extend unit d,w + // time.ParseDuration() is not supported. eg: "1d", "2w" + durStrReg = regexp.MustCompile(`^(-?\d+)(ns|us|µs|ms|s|m|h|d|w)$`) + // match long duration string, such as "1hour", "2hours", "3minutes", "4mins", "5days", "1weeks" + // time.ParseDuration() is not supported. + durStrRegL = regexp.MustCompile(`^(-?\d+)([a-zA-Z]{3,})$`) +) + +// IsDuration check the string is a duration string. +func IsDuration(s string) bool { + if s == "0" || durStrReg.MatchString(s) { + return true + } + return durStrRegL.MatchString(s) +} + +// ToDuration parses a duration string. such as "300ms", "-1.5h" or "2h45m". +// Valid time units are "ns", "us" (or "µs"), "ms", "s", "m", "h". +// +// Diff of time.ParseDuration: +// - support extend unit d, w at the end of string. such as "1d", "2w". +// - support long string unit at end. such as "1hour", "2hours", "3minutes", "4mins", "5days", "1weeks". +// +// If the string is not a valid duration string, it will return an error. +func ToDuration(s string) (time.Duration, error) { + ln := len(s) + if ln == 0 { + return 0, fmt.Errorf("empty duration string") + } + + s = strings.ToLower(s) + if s == "0" { + return 0, nil + } + + // extend unit d,w, time.ParseDuration() is not supported. eg: "1d", "2w" + if lastUnit := s[ln-1]; lastUnit == 'd' { + s = s + "ay" + } else if lastUnit == 'w' { + s = s + "eek" + } + + // long unit, time.ParseDuration() is not supported. eg: "-3sec" => [3sec -3 sec] + ss := durStrRegL.FindStringSubmatch(s) + if len(ss) == 3 { + num, unit := ss[1], ss[2] + + // convert to short unit + switch unit { + case "week", "weeks": + // max unit is hour, so need convert by 24 * 7 * n + n, _ := strconv.Atoi(num) + s = strconv.Itoa(n*24*7) + "h" + case "day", "days": + // max unit is hour, so need convert by 24 * n + n, _ := strconv.Atoi(num) + s = strconv.Itoa(n*24) + "h" + case "hour", "hours": + s = num + "h" + case "min", "mins", "minute", "minutes": + s = num + "m" + case "sec", "secs", "second", "seconds": + s = num + "s" + } + } + + return time.ParseDuration(s) +} diff --git a/ntime/gotime.go b/ntime/gotime.go new file mode 100644 index 0000000..de3e883 --- /dev/null +++ b/ntime/gotime.go @@ -0,0 +1,115 @@ +package ntime + +import "time" + +var ( + // DefaultLayout template for format time + DefaultLayout = time.DateTime + // ZeroTime zero time instance + ZeroTime = time.Time{} +) + +// SetLocalByName set local by tz name. eg: UTC, PRC, Local +func SetLocalByName(tzName string) error { + location, err := time.LoadLocation(tzName) + if err != nil { + return err + } + + time.Local = location + return nil +} + +// NowAddDay add some day time from now +func NowAddDay(day int) time.Time { + return time.Now().AddDate(0, 0, day) +} + +// NowAddHour add some hour time from now +func NowAddHour(hour int) time.Time { + return time.Now().Add(time.Duration(hour) * OneHour) +} + +// NowAddMinutes add some minutes time from now +func NowAddMinutes(minutes int) time.Time { + return time.Now().Add(time.Duration(minutes) * OneMin) +} + +// NowAddSec add some seconds time from now. alias of NowAddSeconds() +func NowAddSec(seconds int) time.Time { + return time.Now().Add(time.Duration(seconds) * time.Second) +} + +// NowAddSeconds add some seconds time from now +func NowAddSeconds(seconds int) time.Time { + return time.Now().Add(time.Duration(seconds) * time.Second) +} + +// NowHourStart time +func NowHourStart() time.Time { + return HourStart(time.Now()) +} + +// NowHourEnd time +func NowHourEnd() time.Time { + return HourEnd(time.Now()) +} + +// AddDay add some day time for given time +func AddDay(t time.Time, day int) time.Time { + return t.AddDate(0, 0, day) +} + +// AddHour add some hour time for given time +func AddHour(t time.Time, hour int) time.Time { + return t.Add(time.Duration(hour) * OneHour) +} + +// AddMinutes add some minutes time for given time +func AddMinutes(t time.Time, minutes int) time.Time { + return t.Add(time.Duration(minutes) * OneMin) +} + +// AddSeconds add some seconds time for given time +func AddSeconds(t time.Time, seconds int) time.Time { + return t.Add(time.Duration(seconds) * time.Second) +} + +// AddSec add some seconds time for given time. alias of AddSeconds() +func AddSec(t time.Time, seconds int) time.Time { + return t.Add(time.Duration(seconds) * time.Second) +} + +// HourStart time for given time +func HourStart(t time.Time) time.Time { + y, m, d := t.Date() + return time.Date(y, m, d, t.Hour(), 0, 0, 0, t.Location()) +} + +// HourEnd time for given time +func HourEnd(t time.Time) time.Time { + y, m, d := t.Date() + return time.Date(y, m, d, t.Hour(), 59, 59, int(time.Second-time.Nanosecond), t.Location()) +} + +// DayStart time for given time +func DayStart(t time.Time) time.Time { + y, m, d := t.Date() + return time.Date(y, m, d, 0, 0, 0, 0, t.Location()) +} + +// DayEnd time for given time +func DayEnd(t time.Time) time.Time { + y, m, d := t.Date() + return time.Date(y, m, d, 23, 59, 59, int(time.Second-time.Nanosecond), t.Location()) +} + +// TodayStart time +func TodayStart() time.Time { + return DayStart(time.Now()) +} + +// TodayEnd time +func TodayEnd() time.Time { + return DayEnd(time.Now()) +} diff --git a/ntime/ntime.go b/ntime/ntime.go new file mode 100644 index 0000000..cd8dba8 --- /dev/null +++ b/ntime/ntime.go @@ -0,0 +1,344 @@ +package ntime + +import "time" + +// provide some commonly time const +const ( + OneSecond = 1 + OneMinSec = 60 + OneHourSec = 3600 + OneDaySec = 86400 + OneWeekSec = 7 * 86400 + + Microsecond = time.Microsecond + Millisecond = time.Millisecond + + Second = time.Second + OneMin = time.Minute + Minute = time.Minute + OneHour = time.Hour + Hour = time.Hour + OneDay = 24 * time.Hour + Day = OneDay + OneWeek = 7 * 24 * time.Hour + Week = OneWeek +) + +// NTime alias of Time +type NTime = Time + +// Time an enhanced time.Time implementation. +type Time struct { + time.Time + // Layout set the default date format layout. default use DefaultLayout + Layout string +} + +/************************************************************* + * Create ntime instance + *************************************************************/ + +// Now time instance +func Now() *Time { + return &Time{Time: time.Now(), Layout: DefaultLayout} +} + +// New instance form given time +func New(t time.Time) *Time { + return &Time{Time: t, Layout: DefaultLayout} +} + +// Wrap the go time instance. alias of the New() +func Wrap(t time.Time) *Time { + return &Time{Time: t, Layout: DefaultLayout} +} + +// FromTime new instance form given time.Time. alias of the New() +func FromTime(t time.Time) *Time { + return &Time{Time: t, Layout: DefaultLayout} +} + +// Local time for now +func Local() *Time { + return New(time.Now().In(time.Local)) +} + +// FromUnix create from unix time +func FromUnix(sec int64) *Time { + return New(time.Unix(sec, 0)) +} + +// FromDate create from datetime string. +func FromDate(s string, template ...string) (*Time, error) { + if len(template) > 0 && template[0] != "" { + return FromString(s, ToLayout(template[0])) + } + return FromString(s) +} + +// FromString create from datetime string. see nstr.ToTime() +func FromString(s string, layouts ...string) (*Time, error) { + t, err := ParseTime(s, layouts...) + if err != nil { + return nil, err + } + return New(t), nil +} + +// LocalByName time for now +func LocalByName(tzName string) *Time { + loc, err := time.LoadLocation(tzName) + if err != nil { + panic(err) + } + + return New(time.Now().In(loc)) +} + +/************************************************************* + * Usage + *************************************************************/ + +// T returns the t.Time +func (t *Time) T() time.Time { + return t.Time +} + +// Format returns a textual representation of the time value formatted according to the layout defined by the argument. +// +// Example: +// // all go-style layout is ok +// // some java-style layout is ok +// tn := ntime.Now() +// tn.Format("yyyy-MM-dd HH:mm:ss") // Output: 2019-01-01 12:12:12 +// tn.Format("yyyy-MM-dd HH:mm") // Output: 2019-01-01 12:12 +// tn.Format("yyyy-MM-dd") // Output: 2019-01-01 +// tn.Format("yyyy-MM") // Output: 2019-01 +// tn.Format("yy-MM-dd") // Output: 19-01-01 +// tn.Format("ymmdd") // Output: 190101 +// +// see time.Format() +func (t *Time) Format(template string) string { + if template == "" { + template = t.Layout + } + return t.Time.Format(ToLayout(template)) +} + +// Datetime use DefaultLayout format time to date. see Format() +func (t *Time) Datetime() string { + return t.Format(t.Layout) +} + +// Yesterday get day ago time for the time +func (t *Time) Yesterday() *Time { + return t.AddSeconds(-OneDaySec) +} + +// DayAgo get some day ago time for the time +func (t *Time) DayAgo(day int) *Time { + return t.AddSeconds(-day * OneDaySec) +} + +// AddDay add some day time for the time +func (t *Time) AddDay(day int) *Time { + return t.AddSeconds(day * OneDaySec) +} + +// SubDay add some day time for the time +func (t *Time) SubDay(day int) *Time { + return t.AddSeconds(-day * OneDaySec) +} + +// Tomorrow time. get tomorrow time for the time +func (t *Time) Tomorrow() *Time { + return t.AddSeconds(OneDaySec) +} + +// DayAfter get some day after time for the time. +// alias of Time.AddDay() +func (t *Time) DayAfter(day int) *Time { + return t.AddDay(day) +} + +// AddDur some duration time +func (t *Time) AddDur(dur time.Duration) *Time { + return &Time{ + Time: t.Add(dur), + Layout: DefaultLayout, + } +} + +// AddString add duration time string. +// +// Example: +// +// tn := timex.Now() // example as "2019-01-01 12:12:12" +// nt := tn.AddString("1h") +// nt.Datetime() // Output: 2019-01-01 13:12:12 +func (t *Time) AddString(dur string) *Time { + d, err := ToDuration(dur) + if err != nil { + panic(err) + } + + return t.AddDur(d) +} + +// AddHour add some hour time +func (t *Time) AddHour(hours int) *Time { + return t.AddSeconds(hours * OneHourSec) +} + +// SubHour add some hour time +func (t *Time) SubHour(hours int) *Time { + return t.AddSeconds(-hours * OneHourSec) +} + +// AddMinutes add some minutes time for the time +func (t *Time) AddMinutes(minutes int) *Time { + return t.AddSeconds(minutes * OneMinSec) +} + +// SubMinutes add some minutes time for the time +func (t *Time) SubMinutes(minutes int) *Time { + return t.AddSeconds(-minutes * OneMinSec) +} + +// AddSeconds add some seconds time the time +func (t *Time) AddSeconds(seconds int) *Time { + return &Time{ + Time: t.Add(time.Duration(seconds) * time.Second), + // with layout + Layout: DefaultLayout, + } +} + +// SubSeconds add some seconds time the time +func (t *Time) SubSeconds(seconds int) *Time { + return &Time{ + Time: t.Add(time.Duration(-seconds) * time.Second), + // with layout + Layout: DefaultLayout, + } +} + +// Diff calc diff duration for t - u. +// alias of time.Time.Sub() +func (t *Time) Diff(u time.Time) time.Duration { + return t.Sub(u) +} + +// DiffSec calc diff seconds for t - u +func (t *Time) DiffSec(u time.Time) int { + return int(t.Sub(u) / time.Second) +} + +// DiffUnix calc diff seconds for t.Unix() - u +func (t *Time) DiffUnix(u int64) int { + return int(t.Unix() - u) +} + +// SubUnix calc diff seconds for t - u +func (t *Time) SubUnix(u time.Time) int { + return int(t.Sub(u) / time.Second) +} + +// HourStart time +func (t *Time) HourStart() *Time { + y, m, d := t.Date() + newTime := time.Date(y, m, d, t.Hour(), 0, 0, 0, t.Location()) + + return New(newTime) +} + +// HourEnd time +func (t *Time) HourEnd() *Time { + y, m, d := t.Date() + newTime := time.Date(y, m, d, t.Hour(), 59, 59, int(time.Second-time.Nanosecond), t.Location()) + + return New(newTime) +} + +// DayStart get time at 00:00:00 +func (t *Time) DayStart() *Time { + y, m, d := t.Date() + newTime := time.Date(y, m, d, 0, 0, 0, 0, t.Location()) + + return New(newTime) +} + +// DayEnd get time at 23:59:59 +func (t *Time) DayEnd() *Time { + y, m, d := t.Date() + newTime := time.Date(y, m, d, 23, 59, 59, int(time.Second-time.Nanosecond), t.Location()) + + return New(newTime) +} + +// CustomHMS custom change the hour, minute, second for create new time. +func (t *Time) CustomHMS(hour, min, sec int) *Time { + y, m, d := t.Date() + newTime := time.Date(y, m, d, hour, min, sec, int(time.Second-time.Nanosecond), t.Location()) + + return FromTime(newTime) +} + +// IsBefore the given time +func (t *Time) IsBefore(u time.Time) bool { + return t.Before(u) +} + +// IsBeforeUnix the given unix timestamp +func (t *Time) IsBeforeUnix(ux int64) bool { + return t.Before(time.Unix(ux, 0)) +} + +// IsAfter the given time +func (t *Time) IsAfter(u time.Time) bool { + return t.After(u) +} + +// IsAfterUnix the given unix timestamp +func (t *Time) IsAfterUnix(ux int64) bool { + return t.After(time.Unix(ux, 0)) +} + +// Timestamp value. alias of t.Unix() +func (t *Time) Timestamp() int64 { + return t.Unix() +} + +// HowLongAgo format diff time to string. +func (t *Time) HowLongAgo(before time.Time) string { + return HowLongAgo(t.Unix() - before.Unix()) +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +// +// Tip: will auto match a format by ParseTime +func (t *Time) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + // Fractional seconds are handled implicitly by Parse. + tt, err := ParseTime(string(data[1 : len(data)-1])) + if err == nil { + t.Time = tt + } + return err +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// +// Tip: will auto match a format by ParseTime +func (t *Time) UnmarshalText(data []byte) error { + // Fractional seconds are handled implicitly by Parse. + tt, err := ParseTime(string(data)) + if err == nil { + t.Time = tt + } + return err +} diff --git a/ntime/template.go b/ntime/template.go new file mode 100644 index 0000000..033f394 --- /dev/null +++ b/ntime/template.go @@ -0,0 +1,43 @@ +package ntime + +import "git.noahlan.cn/noahlan/ntool/nstr/ac" + +var matches = []string{ + "ddd", "dd", "d", + "HH", "hh", "h", // HH:0-23 hh:0-12 + "mm", "m", + "ss", "s", + "yyyy", "yy", "y", + "SSS", + "a", "aa", + "MMMM", "MMM", "MM", "M", + "ZZ", "Z", "zz:zz", "zzzz", "z", + "EEEE", "E", +} + +var replaceWith = []string{ + "_2", "02", "2", + "15", "03", "3", + "04", "4", + "05", "5", + "2006", "06", "06", + "000", + "pm", "PM", + "January", "Jan", "01", "1", + "-0700", "-07", "Z07:00", "Z0700", "MST", + "Monday", "Mon", +} + +// ToLayout converts java date format, +// https://docs.oracle.com/javase/6/docs/api/java/text/SimpleDateFormat.html#rfc822timezone into go date layout +func ToLayout(dateFormat string) string { + acBuilder := ac.NewAhoCorasickBuilder(ac.Opts{ + AsciiCaseInsensitive: false, // 大小写敏感 + MatchOnlyWholeWords: false, + MatchKind: ac.LeftMostLongestMatch, // 最左最长匹配 + DFA: true, + }) + aho := acBuilder.Build(matches) + replacer := ac.NewReplacer(aho) + return replacer.ReplaceAll(dateFormat, replaceWith) +} diff --git a/ntime/ticker.go b/ntime/ticker.go new file mode 100644 index 0000000..5a91ea4 --- /dev/null +++ b/ntime/ticker.go @@ -0,0 +1,78 @@ +package ntime + +import ( + "errors" + "time" +) + +// errTimeout indicates a timeout. +var errTimeout = errors.New("timeout") + +type ( + // Ticker interface wraps the Chan and Stop methods. + Ticker interface { + Chan() <-chan time.Time + Stop() + } + + // FakeTicker interface is used for unit testing. + FakeTicker interface { + Ticker + Done() + Tick() + Wait(d time.Duration) error + } + + fakeTicker struct { + c chan time.Time + done chan struct{} + } + + realTicker struct { + *time.Ticker + } +) + +// NewTicker returns a Ticker. +func NewTicker(d time.Duration) Ticker { + return &realTicker{ + Ticker: time.NewTicker(d), + } +} + +func (rt *realTicker) Chan() <-chan time.Time { + return rt.C +} + +// NewFakeTicker returns a FakeTicker. +func NewFakeTicker() FakeTicker { + return &fakeTicker{ + c: make(chan time.Time, 1), + done: make(chan struct{}, 1), + } +} + +func (ft *fakeTicker) Chan() <-chan time.Time { + return ft.c +} + +func (ft *fakeTicker) Done() { + ft.done <- struct{}{} +} + +func (ft *fakeTicker) Stop() { + close(ft.c) +} + +func (ft *fakeTicker) Tick() { + ft.c <- time.Now() +} + +func (ft *fakeTicker) Wait(d time.Duration) error { + select { + case <-time.After(d): + return errTimeout + case <-ft.done: + return nil + } +} diff --git a/ntime/util.go b/ntime/util.go new file mode 100644 index 0000000..9b69d0a --- /dev/null +++ b/ntime/util.go @@ -0,0 +1,235 @@ +package ntime + +import ( + "fmt" + "git.noahlan.cn/noahlan/ntool/nstr" + "strings" + "time" +) + +// ReprOfDuration returns the string representation of given duration in ms. +func ReprOfDuration(duration time.Duration) string { + return fmt.Sprintf("%.1fms", float32(duration)/float32(time.Millisecond)) +} + +// ElapsedTime calc elapsed time 计算运行时间消耗 单位 ms(毫秒) +func ElapsedTime(startTime time.Time) string { + return fmt.Sprintf("%.3f", time.Since(startTime).Seconds()*1000) +} + +// TryToTime parse a date string or duration string to time.Time. +// +// if s is empty, return zero time. +func TryToTime(s string, bt time.Time) (time.Time, error) { + if s == "" { + return ZeroTime, nil + } + if s == "now" { + return time.Now(), nil + } + + // if s is a duration string, add it to bt(base time) + if IsDuration(s) { + dur, err := ToDuration(s) + if err != nil { + return ZeroTime, err + } + return bt.Add(dur), nil + } + + // as a date string, parse it to time.Time + return ParseTime(s) +} + +// InRange check the dst time is in the range of start and end. +// +// if start is zero, only check dst < end, +// if end is zero, only check dst > start. +func InRange(dst, start, end time.Time) bool { + if start.IsZero() && end.IsZero() { + return false + } + + if start.IsZero() { + return dst.Before(end) + } + if end.IsZero() { + return dst.After(start) + } + + return dst.After(start) && dst.Before(end) +} + +// ParseRangeOpt is the option for ParseRange +type ParseRangeOpt struct { + // BaseTime is the base time for relative time string. + // if is zero, use time.Now() as base time. + BaseTime time.Time + // OneAsEnd is the option for one time range. + // - False: "-1h" => "-1h,0"; "1h" => "+1h, feature" + // - True: "-1h" => "zero,-1h"; "1h" => "zero,1h" + OneAsEnd bool + // AutoSort is the option for sort the time range. + AutoSort bool + // SepChar is the separator char for time range string. default is '~' + SepChar byte + // BeforeFn hook for before parse time string. + BeforeFn func(string) string + // KeywordFn is the function for parse keyword time string. + KeywordFn func(string) (time.Time, time.Time, error) +} + +func ensureOpt(opt *ParseRangeOpt) *ParseRangeOpt { + if opt == nil { + opt = &ParseRangeOpt{BaseTime: time.Now(), SepChar: '~'} + } else { + if opt.BaseTime.IsZero() { + opt.BaseTime = time.Now() + } + if opt.SepChar == 0 { + opt.SepChar = '~' + } + } + + return opt +} + +// ParseRange parse time range expression string to time.Time range. +// - "0" is alias of "now" +// +// Expression format: +// +// "-5h~-1h" => 5 hours ago to 1 hour ago +// "1h~5h" => 1 hour after to 5 hours after +// "-1h~1h" => 1 hour ago to 1 hour after +// "-1h" => 1 hour ago to feature. eq "-1h," +// "-1h~0" => 1 hour ago to now. +// "< -1h" OR "~-1h" => 1 hour ago. eq ",-1h" +// "> 1h" OR "1h" => 1 hour after to feature +// // keyword: now, today, yesterday, tomorrow +// "today" => today start to today end +// "yesterday" => yesterday start to yesterday end +// "tomorrow" => tomorrow start to tomorrow end +// +// Usage: +// +// start, end, err := ParseRange("-1h~1h", nil) +// if err != nil { +// log.Fatal(err) +// } +// fmt.Println(start, end) +func ParseRange(expr string, opt *ParseRangeOpt) (start, end time.Time, err error) { + opt = ensureOpt(opt) + expr = strings.TrimSpace(expr) + if expr == "" { + err = fmt.Errorf("invalid time range expr %q", expr) + return + } + + // parse time range. eg: "5h~1h" + if strings.IndexByte(expr, opt.SepChar) > -1 { + s1, s2 := nstr.TrimCut(expr, string(opt.SepChar)) + if s1 == "" && s2 == "" { + err = fmt.Errorf("invalid time range expr: %s", expr) + return + } + + if s1 != "" { + start, err = TryToTime(s1, opt.BaseTime) + if err != nil { + return + } + } + + if s2 != "" { + end, err = TryToTime(s2, opt.BaseTime) + // auto sort range time + if opt.AutoSort && err == nil { + if !start.IsZero() && start.After(end) { + start, end = end, start + } + } + } + + return + } + + // single time. eg: "5h", "1h", "-1h" + if IsDuration(expr) { + tt, err1 := TryToTime(expr, opt.BaseTime) + if err1 != nil { + err = err1 + return + } + + if opt.OneAsEnd { + end = tt + } else { + start = tt + } + return + } + + // with compare operator. eg: "<1h", ">1h" + if expr[0] == '<' || expr[0] == '>' { + tt, err1 := TryToTime(strings.Trim(expr[1:], " ="), opt.BaseTime) + if err1 != nil { + err = err1 + return + } + + if expr[0] == '<' { + end = tt + } else { + start = tt + } + return + } + + // parse keyword time string + switch expr { + case "0": + if opt.OneAsEnd { + end = opt.BaseTime + } else { + start = opt.BaseTime + } + case "now": + if opt.OneAsEnd { + end = time.Now() + } else { + start = time.Now() + } + case "today": + start = DayStart(opt.BaseTime) + end = DayEnd(opt.BaseTime) + case "yesterday": + yd := opt.BaseTime.AddDate(0, 0, -1) + start = DayStart(yd) + end = DayEnd(yd) + case "tomorrow": + td := opt.BaseTime.AddDate(0, 0, 1) + start = DayStart(td) + end = DayEnd(td) + default: + // single datetime. eg: "2019-01-01" + tt, err1 := TryToTime(expr, opt.BaseTime) + if err1 != nil { + if opt.KeywordFn == nil { + err = fmt.Errorf("invalid keyword time string: %s", expr) + return + } + + start, end, err = opt.KeywordFn(expr) + return + } + + if opt.OneAsEnd { + end = tt + } else { + start = tt + } + } + + return +}