Golang 实现分割 SQL 字符串

在开发中的使用场景:

  • 有一个 SQL 文件,存放了大量 SQL 语句,需要程序将其分割成一条条独立的 SQL 语句,然后逐条处理;
  • 在自己实现 Flink 或其它任务执行引擎框架时,用户把想要按顺序执行的 SQL 写到一个文件中,引擎执行时,需要把这个文件中的 SQL 分割成一条条单独运行的 SQL 子语句。

要解决以上问题,需要自己写一个分割 SQL 语句的算法。

可以使用网上开源的库来实现,但是过于臃肿。于是自己用 Golang 实现了一个。

分割 SQL 语句主要依靠分号 ; 作为分隔符,但要考虑到分号也可能出现在字符串(或注释)中,此时它不能作为分割点,例如

SELECT a FROM tt WHERE name="kyle;dust;"; SELECT b FROM c WHERE id=1;

还要考虑 Flink SQL 中的 STATEMENT SET 语法,即在一个语句块中批量执行多条 SQL;块内各条语句末尾的分号不能作为分割点,整个块(直到 END;)应当作为一条语句,例如

EXECUTE STATEMENT SET
BEGIN
INSERT INTO portrait_feature
SELECT  uid,
        Array[
            Row('location_country', CAST(location_country AS STRING), 'STRING')
            ] as data
FROM X;
END;

BEGIN STATEMENT SET;
INSERT INTO portrait_feature SELECT a FROM b;
INSERT INTO portrait_feature SELECT a FROM b;
END;

源码

代码量不多,一个文件 sqlsplit.go 就能搞定:

package sqlsplit

import (
	"strings"
	"unicode/utf8"
)

// SplitConfig holds the settings applied by New's functional options.
type SplitConfig struct {
	// Dialect names the SQL dialect of the input.
	// NOTE(review): nothing in this file reads Dialect yet; it is only stored.
	Dialect string
	// StripSemicolon asks the splitter to drop the terminating ';' from each
	// statement it returns.
	StripSemicolon bool
}

// SplitOption mutates a SplitConfig; pass any number of them to New.
//
// It is exported so callers can name the type, e.g. to assemble a
// []SplitOption dynamically before calling New(input, opts...).
type SplitOption func(*SplitConfig)

// splitOption is the original, unexported name of SplitOption. It is kept as
// an alias so existing references inside the package continue to compile.
type splitOption = SplitOption

// WithDialect returns an option that sets SplitConfig.Dialect.
func WithDialect(dialect string) SplitOption {
	return func(c *SplitConfig) {
		c.Dialect = dialect
	}
}

// WithStripSemicolon returns an option that sets SplitConfig.StripSemicolon,
// so returned statements do not end in ';'.
func WithStripSemicolon() SplitOption {
	return func(c *SplitConfig) {
		c.StripSemicolon = true
	}
}

// Split splits a SQL script into individual statements.
//
// It is a single forward pass over src: a statement ends at a ';' found
// outside quoted strings and comments, except inside a Flink STATEMENT SET,
// which is kept together until the ';' after its END keyword.
// Build one with New and call its Split method.
type Split struct {
	// src is the SQL text being split.
	src string
	// cursor is the byte offset in src of the character currently examined.
	cursor int
	// start is the byte offset in src where the pending statement begins.
	start int
	// idents is a small ring of the most recently seen identifiers, used to
	// recognise the STATEMENT SET and END keywords.
	idents *seenIdentifiers
	// stmtSet reports whether the scanner is inside a STATEMENT SET block.
	stmtSet bool
	config  *SplitConfig // options supplied to New
}

// New creates a Split over input. The options are applied, in order, to a
// zero-valued SplitConfig before the splitter is returned.
func New(input string, opts ...splitOption) *Split {
	cfg := new(SplitConfig)
	for _, apply := range opts {
		apply(cfg)
	}
	return &Split{
		src:    input,
		idents: newSeenIdentifiers(2), // two slots are enough to spot "STATEMENT SET"
		config: cfg,
	}
}

// Split splits the whole input and returns the statements in source order.
//
// A statement that comes back empty (for example the stray ';' in
// "SELECT 1;; SELECT 2;" once StripSemicolon has removed it) is skipped
// instead of being mistaken for the end of the input, so nothing after it is
// lost. The result is nil when the input holds no statements.
func (s *Split) Split() []string {
	var statements []string
	for {
		if statement := s.scan(); statement != "" {
			statements = append(statements, statement)
		}
		// scan always consumes at least one byte unless the cursor is already
		// at the end of src, so this check guarantees termination.
		if s.cursor >= len(s.src) {
			return statements
		}
	}
}

// stmt finishes the statement that ends at the cursor and returns it.
//
// On entry the cursor is on the terminating ';' (or at the end of the input).
// The ';' is consumed, followed by any blanks and "--" comments trailing it:
// a bare newline ends the statement, but a comment's own newline stays with
// the comment, matching sqlparse. The collected text is trimmed and, when
// StripSemicolon is set, the terminating ';' is removed.
func (s *Split) stmt() string {
	semi := s.cursor // index of the terminating ';', or the end of the input
	_ = s.next()

	// Absorb same-line blanks and single-line comments after the ';'.
	for {
		ch := s.peek()
		if isWhitespace(ch) {
			s.scanWhitespace()
			continue
		}
		if isSingleLineComment(ch, s.lookAhead(1)) {
			s.scanSingleLineComment()
			continue
		}
		break
	}

	raw := s.src[s.start:s.cursor]
	if s.config.StripSemicolon {
		if semi < len(s.src) && s.src[semi] == ';' {
			// Cut the ';' out at its known position so it is removed even
			// when a trailing comment follows it ("SELECT 1; -- note").
			raw = s.src[s.start:semi] + s.src[semi+1:s.cursor]
		} else {
			// Ended at EOF: fall back to dropping a trailing ';', if any.
			raw = strings.TrimSuffix(strings.TrimSpace(raw), ";")
		}
	}
	s.start = s.cursor
	return strings.TrimSpace(raw)
}

// scanIdentifier consumes a run of ASCII letters, digits and underscores
// starting at the cursor and records it in s.idents. When the word is SET and
// the previous identifier was STATEMENT (both case-insensitive), the scanner
// enters Flink STATEMENT SET mode.
func (s *Split) scanIdentifier() {
	begin := s.cursor
	for ch := s.peek(); isLetter(ch) || isNumber(ch) || ch == '_'; ch = s.peek() {
		s.next()
	}
	word := s.src[begin:s.cursor]

	// The word is pure ASCII, so EqualFold is a plain case-insensitive match.
	if strings.EqualFold(word, "SET") && strings.EqualFold(s.idents.current(), "STATEMENT") {
		s.stmtSet = true
	}
	s.idents.add(word)
}

// tryStmt is called by scan when the cursor is on a ';'.
//
// Outside a Flink STATEMENT SET the ';' always ends the statement, which is
// returned with ok == true. Inside a set, only a ';' whose most recently seen
// identifier is END closes the set (and the statement). END is matched
// case-insensitively, just as STATEMENT SET is detected in scanIdentifier, so
// "begin statement set; ...; end;" closes correctly. Any other ';' yields
// ("", false) and scan keeps going.
func (s *Split) tryStmt() (string, bool) {
	if !s.stmtSet {
		return s.stmt(), true
	}
	if strings.EqualFold(s.idents.current(), "END") {
		s.stmtSet = false
		return s.stmt(), true
	}
	return "", false
}

// scan advances over the input until the current statement ends (at a
// terminating ';' or at the end of the input) and returns it trimmed; the
// result may be empty when only blanks or comments remained.
//
// Quoted strings, backtick-quoted identifiers and comments are skipped as
// opaque units, so a ';', a quote or a comment marker inside them never
// affects where a statement ends.
func (s *Split) scan() string {
	for {
		ch := s.peek()
		switch {
		case isEOF(ch):
			return s.stmt()
		case ch == ';':
			if stmt, ok := s.tryStmt(); ok {
				return stmt
			}
			// Inside a STATEMENT SET that has not reached END: this ';'
			// belongs to the set, so keep scanning.
			_ = s.next()
		case isSingleQuote(ch):
			s.scanString('\'')
		case isDoubleQuote(ch):
			s.scanString('"')
		case ch == '`':
			// Backtick-quoted identifiers (e.g. Flink SQL, MySQL) may contain
			// ';', quotes or "--" that must not be interpreted.
			s.scanString('`')
		case isLetter(ch):
			s.scanIdentifier()
		case isSingleLineComment(ch, s.lookAhead(1)):
			s.scanSingleLineComment()
		case isMultiLineComment(ch, s.lookAhead(1)):
			s.scanMultiLineComment()
		default:
			_ = s.next()
		}
	}
}

// scanString skips over a quoted literal delimited by quote. The cursor must
// be on the opening quote; on return it sits just past the closing quote, or
// where isEOF first reports true if the literal is unterminated. A backslash
// makes the following character be skipped unexamined, so an escaped quote
// does not close the literal.
func (s *Split) scanString(quote rune) {
	for ch := s.next(); ; ch = s.next() {
		switch {
		case ch == '\\':
			s.next() // step onto the escaped character; the loop steps past it
		case ch == quote:
			s.next()
			return
		case isEOF(ch):
			return
		}
	}
}

// scanWhitespace consumes a run of blanks (space, tab, carriage return).
// The character at the cursor is assumed to be a blank and is skipped
// unconditionally; newlines are never consumed.
func (s *Split) scanWhitespace() {
	for isWhitespace(s.next()) {
		// next has already advanced the cursor; nothing else to do.
	}
}

// scanSingleLineComment skips a "--" comment. The cursor must be on the first
// '-'. On return it sits just after the newline that ends the comment (the
// newline is deliberately kept with the comment, as sqlparse does), or at the
// end of the input.
func (s *Split) scanSingleLineComment() {
	for ch := s.nextBy(2); !isEOF(ch); ch = s.next() {
		if ch == '\n' {
			s.next()
			return
		}
	}
}

// scanMultiLineComment skips a "/* ... */" comment. The cursor must be on the
// '/'. On return it sits just past the first closing "*/", or at the end of
// the input if the comment is unterminated; comments do not nest.
func (s *Split) scanMultiLineComment() {
	for ch := s.nextBy(2); !isEOF(ch); ch = s.next() {
		if ch == '*' && s.lookAhead(1) == '/' {
			s.nextBy(2)
			return
		}
	}
}

// lookAhead decodes the rune starting n bytes past the cursor without moving
// it. Offsets outside src yield 0, which the scanner treats as end of input.
func (s *Split) lookAhead(n int) rune {
	pos := s.cursor + n
	if pos < 0 || pos >= len(s.src) {
		return 0
	}
	r, _ := utf8.DecodeRuneInString(s.src[pos:])
	return r
}

// peek reports the rune under the cursor (0 past the end) without consuming it.
func (s *Split) peek() rune { return s.lookAhead(0) }

// nextBy moves the cursor forward by n bytes and returns the rune found at the
// new position, or 0 once the end of src is reached. A move that would pass
// the end of src is refused: the cursor stays put and 0 is returned.
func (s *Split) nextBy(n int) rune {
	switch target := s.cursor + n; {
	case target > len(s.src):
		return 0
	case target == len(s.src):
		s.cursor = target
		return 0
	default:
		s.cursor = target
		r, _ := utf8.DecodeRuneInString(s.src[target:])
		return r
	}
}

// next steps the cursor forward by a single byte and returns the rune found
// there (0 at the end of src).
func (s *Split) next() rune { return s.nextBy(1) }

// isEOF reports whether ch is the NUL sentinel that the cursor helpers return
// at the end of the input.
func isEOF(ch rune) bool {
	return ch == '\x00'
}

// isSingleLineComment reports whether ch followed by nextCh opens a "--" comment.
func isSingleLineComment(ch rune, nextCh rune) bool {
	return ch == '-' && ch == nextCh
}

// isMultiLineComment reports whether ch followed by nextCh opens a "/*" comment.
func isMultiLineComment(ch rune, nextCh rune) bool {
	if ch != '/' {
		return false
	}
	return nextCh == '*'
}

// isSingleQuote reports whether ch is an apostrophe, which opens a string literal.
func isSingleQuote(ch rune) bool {
	switch ch {
	case '\'':
		return true
	default:
		return false
	}
}

// isDoubleQuote reports whether ch is a double quote, which opens a quoted
// section the scanner skips like a string.
func isDoubleQuote(ch rune) bool {
	switch ch {
	case '"':
		return true
	default:
		return false
	}
}

// isLetter reports whether ch is an ASCII letter. Non-ASCII letters are
// deliberately excluded: only ASCII keywords matter to the splitter.
func isLetter(ch rune) bool {
	lower := 'a' <= ch && ch <= 'z'
	upper := 'A' <= ch && ch <= 'Z'
	return lower || upper
}

// isNumber reports whether ch is an ASCII decimal digit.
func isNumber(ch rune) bool {
	return '0' <= ch && ch <= '9'
}

// isWhitespace reports whether ch is a blank other than a newline: space, tab
// or carriage return. '\n' is excluded on purpose so that stmt stops absorbing
// trailing text at the end of the line.
func isWhitespace(ch rune) bool {
	switch ch {
	case ' ', '\t', '\r':
		return true
	}
	return false
}

// seenIdentifiers is a fixed-size ring buffer of the identifiers seen most
// recently; once full, each add overwrites the oldest entry. A size of 0
// disables it: add is a no-op and current/String return "".
type seenIdentifiers struct {
	size        int      // capacity of the ring
	identifiers []string // ring storage; "" marks a never-written slot (assumes add never gets "")
	pos         int      // slot the next add will write
}

// newSeenIdentifiers returns an empty ring holding up to size identifiers.
func newSeenIdentifiers(size int) *seenIdentifiers {
	return &seenIdentifiers{size: size, identifiers: make([]string, size)}
}

// add records identifier as the most recent one.
func (s *seenIdentifiers) add(identifier string) {
	if s.size == 0 {
		return
	}
	s.identifiers[s.pos] = identifier
	s.pos++
	if s.pos == s.size {
		s.pos = 0
	}
}

// current returns the most recently added identifier, or "" if there is none.
func (s *seenIdentifiers) current() string {
	if s.size == 0 {
		return ""
	}
	last := s.pos - 1
	if last < 0 {
		last = s.size - 1
	}
	return s.identifiers[last]
}

// String lists the remembered identifiers from oldest to newest, separated by
// single spaces.
func (s *seenIdentifiers) String() string {
	if s.size == 0 || s.identifiers[0] == "" {
		return ""
	}
	// Before the ring wraps, the slot at pos has never been written and the
	// entries already sit in order at the front of the slice.
	if s.identifiers[s.pos] == "" {
		return strings.Join(s.identifiers[:s.pos], " ")
	}
	// After wrapping, the oldest entry is the one about to be overwritten.
	ordered := make([]string, 0, s.size)
	ordered = append(ordered, s.identifiers[s.pos:]...)
	ordered = append(ordered, s.identifiers[:s.pos]...)
	return strings.Join(ordered, " ")
}

代码也放了一份到 GitHub 上:https://gist.github.com/kylege/88c87d1697dda64662280107e7f853e9

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注