在开发中的使用场景:
- 有一个存放了大量 SQL 语句的 SQL 文件,需要程序将其分割成一条条独立的 SQL 语句,然后逐条处理;
- 在自己实现 Flink 或其它任务执行引擎框架时,用户把想要按顺序执行的 SQL 写到一个文件中,引擎执行时,需要把这个文件中的 SQL 分割成一条条单独运行的 SQL 子语句。
要解决以上问题,需要自己写一个分割 SQL 语句的算法。
可以使用网上开源的库来实现,但是过于臃肿。于是自己用 Golang 实现了一个。
分割 SQL 语句主要依靠分号 ; 这个分隔符,但要考虑到分号也有可能出现在字符串中,例如:
SELECT a FROM tt WHERE name="kyle;dust;"; SELECT b FROM c WHERE id=1;
还要考虑到 Flink SQL 中,STATEMENT SET 的语法,也就是批量执行多条 SQL,例如
EXECUTE STATEMENT SET
BEGIN
INSERT INTO portrait_feature
SELECT uid,
Array[
Row('location_country', CAST(location_country AS STRING), 'STRING')
] as data
FROM X;
END;
BEGIN STATEMENT SET;
INSERT INTO portrait_feature SELECT a FROM b;
INSERT INTO portrait_feature SELECT a FROM b;
END;
源码
代码量不多,只用一个文件搞定 sqlsplit.go
package sqlsplit
import (
"strings"
"unicode/utf8"
)
// SplitConfig holds the options that can be set on a Split through the
// With* option functions.
type SplitConfig struct {
// Dialect is the SQL dialect name stored by WithDialect. It is kept as given;
// nothing in this file reads it.
Dialect string
// StripSemicolon, when true, makes each returned statement lose one trailing ";".
StripSemicolon bool
}
// splitOption mutates a SplitConfig. New calls every option it receives on the
// config of the Split it is building.
type splitOption func(*SplitConfig)
// WithDialect returns an option that stores dialect in SplitConfig.Dialect.
func WithDialect(dialect string) splitOption {
	return func(cfg *SplitConfig) { cfg.Dialect = dialect }
}
// WithStripSemicolon returns an option that turns on SplitConfig.StripSemicolon.
func WithStripSemicolon() splitOption {
	enable := func(cfg *SplitConfig) { cfg.StripSemicolon = true }
	return enable
}
// Split breaks a SQL text into its individual statements.
//
// It is a forward-only scanner over src: cursor is the read position and start
// is where the statement currently being scanned began.
type Split struct {
// src is the source SQL string.
src string
// cursor is the current read position, a byte offset into src.
cursor int
// start is the byte offset in src where the current statement begins.
start int
// idents remembers the most recently scanned identifiers.
idents *seenIdentifiers
// stmtSet is true while the scanner is inside a multi-statement STATEMENT SET block.
stmtSet bool
// config holds the options applied by New.
config *SplitConfig
}
// New creates a Split over input.
//
// The config starts as a zero SplitConfig and every option in opts is applied
// to it in order. The identifier history is sized to hold the last two
// identifiers.
func New(input string, opts ...splitOption) *Split {
	cfg := &SplitConfig{}
	for _, apply := range opts {
		apply(cfg)
	}
	return &Split{
		src:    input,
		idents: newSeenIdentifiers(2),
		config: cfg,
	}
}
// Split scans the remaining input and returns its statements in source order.
//
// Statements that come out empty are skipped rather than treated as the end of
// input. With StripSemicolon enabled a stray ";" (as in "SELECT 1;;SELECT 2;")
// produces an empty statement; stopping there would silently drop every
// statement that follows it.
//
// It returns nil when the input contains no non-empty statement.
func (s *Split) Split() []string {
	var statements []string
	// Every call to scan advances the cursor by at least one byte while input
	// remains, so looping on the cursor position always terminates.
	for s.cursor < len(s.src) {
		if statement := s.scan(); statement != "" {
			statements = append(statements, statement)
		}
	}
	return statements
}
// stmt closes the statement that is currently being scanned and returns it.
//
// It first advances past the character under the cursor, then also swallows
// any whitespace and single-line comments that follow, so they end up in the
// returned statement. The text between start and the cursor is trimmed of
// surrounding whitespace; when StripSemicolon is set, one trailing ";" is
// removed and the result is trimmed again. start is moved to the cursor so the
// next statement begins there.
func (s *Split) stmt() string {
	s.next()

	// Newlines are not whitespace here (see isWhitespace), so a bare newline
	// ends this skipping loop.
	for {
		if c := s.peek(); isWhitespace(c) {
			s.scanWhitespace()
		} else if isSingleLineComment(c, s.lookAhead(1)) {
			s.scanSingleLineComment()
		} else {
			break
		}
	}

	text := strings.TrimSpace(s.src[s.start:s.cursor])
	s.start = s.cursor
	if s.config.StripSemicolon {
		text = strings.TrimSpace(strings.TrimSuffix(text, ";"))
	}
	return text
}
// scanIdentifier consumes a run of ASCII letters, digits and underscores
// starting at the cursor and records it in the identifier history.
//
// If the identifier is SET and the previously recorded identifier is STATEMENT
// (both compared case-insensitively), the scanner enters Flink's STATEMENT SET
// mode.
func (s *Split) scanIdentifier() {
	begin := s.cursor
	for c := s.peek(); isLetter(c) || isNumber(c) || c == '_'; c = s.peek() {
		s.next()
	}
	word := s.src[begin:s.cursor]

	// Detect Flink's "STATEMENT SET" keyword pair.
	if strings.EqualFold(word, "SET") && strings.EqualFold(s.idents.current(), "STATEMENT") {
		s.stmtSet = true
	}
	s.idents.add(word)
}
// tryStmt is called when the cursor is on a semicolon and decides whether that
// semicolon ends a statement.
//
// Outside a statement set it always does: the statement is returned with true.
// Inside a statement set the semicolon only ends the statement when the most
// recently scanned identifier is END; in that case the statement-set mode is
// left and the whole block is returned. Otherwise it returns "", false and the
// caller keeps scanning.
//
// END is matched case-insensitively, the same way scanIdentifier matches
// STATEMENT and SET, so a block closed with a lowercase "end;" is recognised
// too.
func (s *Split) tryStmt() (string, bool) {
	if s.stmtSet {
		if !strings.EqualFold(s.idents.current(), "END") {
			return "", false
		}
		s.stmtSet = false
	}
	return s.stmt(), true
}
// scan advances through the input until a statement is complete and returns it.
//
// A statement is complete at end of input or at a semicolon accepted by
// tryStmt. Quoted strings, identifiers and comments are consumed as whole
// units, so semicolons inside them do not end the statement.
func (s *Split) scan() string {
	for {
		cur := s.peek()

		if isEOF(cur) {
			return s.stmt()
		}
		if cur == ';' {
			if out, done := s.tryStmt(); done {
				return out
			}
			// Semicolon inside a statement set: step over it and keep going.
			s.next()
			continue
		}

		switch {
		case isSingleQuote(cur):
			s.scanString('\'')
		case isDoubleQuote(cur):
			s.scanString('"')
		case isLetter(cur):
			s.scanIdentifier()
		case isSingleLineComment(cur, s.lookAhead(1)):
			s.scanSingleLineComment()
		case isMultiLineComment(cur, s.lookAhead(1)):
			s.scanMultiLineComment()
		default:
			s.next()
		}
	}
}
// scanString consumes a quoted string whose opening quote is under the cursor.
//
// A backslash escapes the character after it, so an escaped quote does not end
// the string. Scanning stops just past the closing quote, or at end of input
// when the string is never closed.
func (s *Split) scanString(quote rune) {
	// The first next() steps over the opening quote.
	for ch := s.next(); ; ch = s.next() {
		switch {
		case ch == '\\':
			// Move onto the escaped character; the loop's own advance then
			// steps past it without inspecting it.
			s.next()
		case ch == quote:
			s.next() // step over the closing quote
			return
		case isEOF(ch):
			// Input ended before the closing quote (truncated string).
			return
		}
	}
}
// scanWhitespace advances the cursor until it rests on a character that is not
// a space, tab or carriage return. The character under the cursor on entry is
// always skipped.
func (s *Split) scanWhitespace() {
	for {
		if !isWhitespace(s.next()) {
			return
		}
	}
}
// scanSingleLineComment consumes a "--" comment that starts at the cursor,
// through the end of the line or the end of input.
func (s *Split) scanSingleLineComment() {
	ch := s.nextBy(2) // step over the two opening dashes
	for !isEOF(ch) {
		if ch == '\n' {
			// The newline is consumed as part of the comment, to stay
			// consistent with the sqlparse library.
			s.next()
			return
		}
		ch = s.next()
	}
}
// scanMultiLineComment consumes a "/* ... */" comment that starts at the
// cursor. If the comment is never closed, scanning stops at end of input.
func (s *Split) scanMultiLineComment() {
	// nextBy(2) steps over the opening "/*".
	for ch := s.nextBy(2); !isEOF(ch); ch = s.next() {
		if ch == '*' && s.lookAhead(1) == '/' {
			s.nextBy(2) // step over the closing "*/"
			return
		}
	}
}
// lookAhead decodes the rune that starts n bytes past the cursor, without
// moving the cursor. It returns 0 when that offset lies outside src.
func (s *Split) lookAhead(n int) rune {
	pos := s.cursor + n
	if pos < 0 || pos >= len(s.src) {
		return 0
	}
	r, _ := utf8.DecodeRuneInString(s.src[pos:])
	return r
}
// peek returns the rune under the cursor without consuming it, or 0 at end of
// input.
func (s *Split) peek() rune {
	const here = 0
	return s.lookAhead(here)
}
// nextBy moves the cursor forward by n bytes and returns the rune at the new
// position.
//
// If the move would go past the end of src the cursor is left untouched and 0
// is returned. 0 is also returned when the cursor lands exactly on the end of
// src.
func (s *Split) nextBy(n int) rune {
	target := s.cursor + n
	if target > len(s.src) {
		return 0
	}
	s.cursor = target
	if target >= len(s.src) {
		return 0
	}
	r, _ := utf8.DecodeRuneInString(s.src[target:])
	return r
}
// next moves the cursor forward by one byte and returns the rune at the new
// position, or 0 at end of input.
func (s *Split) next() rune {
	const step = 1
	return s.nextBy(step)
}
// isEOF reports whether ch is the zero rune, which the scanner uses as its
// end-of-input marker.
func isEOF(ch rune) bool {
	const eof rune = 0
	return ch == eof
}
// isSingleLineComment reports whether ch followed by nextCh is the "--" that
// opens a single-line comment.
func isSingleLineComment(ch rune, nextCh rune) bool {
	if ch != '-' {
		return false
	}
	return nextCh == '-'
}
// isMultiLineComment reports whether ch followed by nextCh is the "/*" that
// opens a multi-line comment.
func isMultiLineComment(ch rune, nextCh rune) bool {
	if ch != '/' {
		return false
	}
	return nextCh == '*'
}
// isSingleQuote reports whether ch is the single-quote character.
func isSingleQuote(ch rune) bool {
	const singleQuote = '\''
	return ch == singleQuote
}
// isDoubleQuote reports whether ch is the double-quote character.
func isDoubleQuote(ch rune) bool {
	const doubleQuote = '"'
	return ch == doubleQuote
}
// isLetter reports whether ch is an ASCII letter, a-z or A-Z.
func isLetter(ch rune) bool {
	switch {
	case 'a' <= ch && ch <= 'z', 'A' <= ch && ch <= 'Z':
		return true
	}
	return false
}
// isNumber reports whether ch is an ASCII digit, 0-9.
func isNumber(ch rune) bool {
	return '0' <= ch && ch <= '9'
}
// isWhitespace reports whether ch is a space, tab or carriage return.
// The newline character is deliberately not counted as whitespace.
func isWhitespace(ch rune) bool {
	switch ch {
	case ' ', '\t', '\r':
		return true
	}
	return false
}
// seenIdentifiers is a fixed-size ring buffer that remembers the most recently
// added identifiers.
type seenIdentifiers struct {
	// size is the number of slots in the ring.
	size int
	// identifiers is the ring storage; a slot that was never written holds "".
	identifiers []string
	// pos is the index of the slot the next add will write to.
	pos int
}

// newSeenIdentifiers returns an empty ring with size slots.
func newSeenIdentifiers(size int) *seenIdentifiers {
	ring := new(seenIdentifiers)
	ring.size = size
	ring.identifiers = make([]string, size)
	return ring
}

// add stores identifier in the ring, overwriting the oldest entry once the
// ring is full. It does nothing on a zero-size ring.
func (s *seenIdentifiers) add(identifier string) {
	if s.size == 0 {
		return
	}
	s.identifiers[s.pos] = identifier
	s.pos++
	if s.pos == s.size {
		s.pos = 0 // wrap around to the first slot
	}
}

// current returns the identifier that was added last, or "" when nothing has
// been added or the ring has zero size.
func (s *seenIdentifiers) current() string {
	if s.size == 0 {
		return ""
	}
	last := s.pos - 1
	if last < 0 {
		last = s.size - 1
	}
	return s.identifiers[last]
}

// String returns the remembered identifiers, oldest first, joined by single
// spaces. It returns "" for a zero-size ring or when the first slot is empty.
func (s *seenIdentifiers) String() string {
	if s.size == 0 || s.identifiers[0] == "" {
		return ""
	}
	if s.identifiers[s.pos] == "" {
		// The ring has not wrapped yet: slots [0, pos) are already in order.
		return strings.Join(s.identifiers[:s.pos], " ")
	}
	// The ring is full: the oldest entry sits at pos, the newest just before it.
	ordered := make([]string, 0, s.size)
	ordered = append(ordered, s.identifiers[s.pos:]...)
	ordered = append(ordered, s.identifiers[:s.pos]...)
	return strings.Join(ordered, " ")
}
也放了一份代码到 Github 上:https://gist.github.com/kylege/88c87d1697dda64662280107e7f853e9