Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
273 changes: 273 additions & 0 deletions internal/compiler/infer_expr_type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
package compiler

import (
"github.com/sqlc-dev/sqlc/internal/config"
"github.com/sqlc-dev/sqlc/internal/sql/ast"
)

//
// ==============================
// Internal Type System
// ==============================
//

type Kind int

const (
KindUnknown Kind = iota // inference not supported
KindInt
KindFloat
KindDecimal
KindAny
)

type Type struct {
Kind Kind
NotNull bool
Valid bool // explicit signal: inference succeeded
}

func unknownType() Type {
return Type{Kind: KindUnknown, Valid: false}
}

//
// ==============================
// Entry Point
// ==============================
//

func (c *Compiler) inferExprType(node ast.Node, tables []*Table) *Column {
if node == nil {
return nil
}

switch c.conf.Engine {
case config.EngineMySQL:
t := c.inferMySQLExpr(node, tables)
return c.mysqlTypeToColumn(t)

// case config.EnginePostgreSQL:
// t := c.inferPostgresExpr(node, tables)
// return c.postgresTypeToColumn(t)

default:
return nil
}
}

//
// ==============================
// MySQL Inference
// ==============================
//

func (c *Compiler) inferMySQLExpr(node ast.Node, tables []*Table) Type {
switch n := node.(type) {
case *ast.ColumnRef:
return c.inferMySQLColumnRef(n, tables)

case *ast.A_Const:
return inferConst(n)

case *ast.TypeCast:
return c.inferMySQLTypeCast(n, tables)

case *ast.A_Expr:
return c.inferMySQLBinary(n, tables)

default:
return unknownType()
}
}

//
// ------------------------------
// Leaf nodes
// ------------------------------
//

func (c *Compiler) inferMySQLColumnRef(ref *ast.ColumnRef, tables []*Table) Type {
cols, err := outputColumnRefs(&ast.ResTarget{}, tables, ref)
if err != nil || len(cols) == 0 {
return unknownType()
}

col := cols[0]

return Type{
Kind: mapMySQLKind(col.DataType),
NotNull: col.NotNull,
Valid: true,
}
}

func inferConst(node *ast.A_Const) Type {
if node == nil || node.Val == nil {
return unknownType()
}

switch node.Val.(type) {
case *ast.Integer:
return Type{Kind: KindInt, NotNull: true, Valid: true}

case *ast.Float:
return Type{Kind: KindFloat, NotNull: true, Valid: true}

case *ast.Null:
return Type{Kind: KindAny, NotNull: false, Valid: true}

default:
return unknownType()
}
}

func (c *Compiler) inferMySQLTypeCast(node *ast.TypeCast, tables []*Table) Type {
if node == nil || node.TypeName == nil {
return unknownType()
}

base := toColumn(node.TypeName)
if base == nil {
return unknownType()
}

arg := c.inferMySQLExpr(node.Arg, tables)

t := Type{
Kind: mapMySQLKind(base.DataType),
Valid: true,
}

// propagate nullability
if arg.Valid {
t.NotNull = arg.NotNull
}

// explicit NULL literal
if constant, ok := node.Arg.(*ast.A_Const); ok {
if _, isNull := constant.Val.(*ast.Null); isNull {
t.NotNull = false
}
}

return t
}

//
// ------------------------------
// Binary expressions
// ------------------------------
//

func (c *Compiler) inferMySQLBinary(node *ast.A_Expr, tables []*Table) Type {
op := joinOperator(node)

left := c.inferMySQLExpr(node.Lexpr, tables)
right := c.inferMySQLExpr(node.Rexpr, tables)

if !left.Valid || !right.Valid {
return unknownType()
}

// NOTE: only normal division ("/") is supported for now.
// Unsupported operators intentionally fall back to the existing behavior.
return promoteMySQLNumeric(op, left, right)
}

//
// ==============================
// Promotion Rules (MySQL-specific for now)
// ==============================
//

// promoteMySQLNumeric applies simplified numeric promotion rules for MySQL.
// It currently only supports "/" and intentionally falls back for other operators.
func promoteMySQLNumeric(op string, a, b Type) Type {
notNull := a.NotNull && b.NotNull

switch op {
case "/":
if a.Kind == KindFloat || b.Kind == KindFloat {
return Type{
Kind: KindFloat,
NotNull: notNull,
Valid: true,
}
}

return Type{
Kind: KindDecimal,
NotNull: notNull,
Valid: true,
}
}

return unknownType()
}

//
// ==============================
// Engine-specific Mapping
// ==============================
//

func (c *Compiler) mysqlTypeToColumn(t Type) *Column {
if !t.Valid {
return nil
}

col := &Column{
NotNull: t.NotNull,
}

switch t.Kind {
case KindInt:
col.DataType = "int"

case KindFloat:
col.DataType = "float"

case KindDecimal:
col.DataType = "decimal"

default:
col.DataType = "any"
}

return col
}

func mapMySQLKind(dt string) Kind {
switch dt {
case "int", "integer", "bigint", "smallint":
return KindInt

case "float", "double", "real":
return KindFloat

case "decimal", "numeric":
return KindDecimal

default:
return KindUnknown
}
}

//
// ==============================
// AST helpers
// ==============================
//

func joinOperator(node *ast.A_Expr) string {
if node == nil || node.Name == nil || len(node.Name.Items) == 0 {
return ""
}

if s, ok := node.Name.Items[0].(*ast.String); ok {
return s.Str
}

return ""
}
8 changes: 7 additions & 1 deletion internal/compiler/output_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,13 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
// TODO: Generate a name for these operations
cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true})
case lang.IsMathematicalOperator(op):
cols = append(cols, &Column{Name: name, DataType: "int", NotNull: true})
if inferredCol := c.inferExprType(n, tables); inferredCol != nil {
inferredCol.Name = name
inferredCol.skipTableRequiredCheck = true
cols = append(cols, inferredCol)
} else {
cols = append(cols, &Column{Name: name, DataType: "int", NotNull: true})
}
default:
cols = append(cols, &Column{Name: name, DataType: "any", NotNull: false})
}
Expand Down
4 changes: 3 additions & 1 deletion internal/engine/dolphin/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ func opToName(o opcode.Op) string {
// case opcode.BitNeg:
// case opcode.Case:
// case opcode.Div:
case opcode.Div:
return "/"
case opcode.EQ:
return "="
case opcode.GE:
Expand All @@ -145,7 +147,7 @@ func opToName(o opcode.Op) string {
return ">"
// case opcode.In:
case opcode.IntDiv:
return "/"
return "div"
// case opcode.IsFalsity:
// case opcode.IsNull:
// case opcode.IsTruth:
Expand Down
Loading