diff --git a/net/netfilter/nft_socket.c b/net/netfilter/nft_socket.c index 6d9e8e0a3a7d..d6da68a3b739 100644 --- a/net/netfilter/nft_socket.c +++ b/net/netfilter/nft_socket.c @@ -59,21 +59,27 @@ static void nft_socket_eval(const struct nft_expr *expr, const struct nft_pktinfo *pkt) { const struct nft_socket *priv = nft_expr_priv(expr); + u32 *dest = ®s->data[priv->dreg]; struct sk_buff *skb = pkt->skb; + const struct net_device *dev; struct sock *sk = skb->sk; - u32 *dest = ®s->data[priv->dreg]; if (sk && !net_eq(nft_net(pkt), sock_net(sk))) sk = NULL; - if (!sk) + if (nft_hook(pkt) == NF_INET_LOCAL_OUT) + dev = nft_out(pkt); + else + dev = nft_in(pkt); + + if (!sk) { switch(nft_pf(pkt)) { case NFPROTO_IPV4: - sk = nf_sk_lookup_slow_v4(nft_net(pkt), skb, nft_in(pkt)); + sk = nf_sk_lookup_slow_v4(nft_net(pkt), skb, dev); break; #if IS_ENABLED(CONFIG_NF_TABLES_IPV6) case NFPROTO_IPV6: - sk = nf_sk_lookup_slow_v6(nft_net(pkt), skb, nft_in(pkt)); + sk = nf_sk_lookup_slow_v6(nft_net(pkt), skb, dev); break; #endif default: @@ -81,6 +87,7 @@ static void nft_socket_eval(const struct nft_expr *expr, regs->verdict.code = NFT_BREAK; return; } + } if (!sk) { regs->verdict.code = NFT_BREAK; @@ -184,6 +191,15 @@ static int nft_socket_init(const struct nft_ctx *ctx, NULL, NFT_DATA_VALUE, len); } +static int nft_socket_validate(const struct nft_ctx *ctx, + const struct nft_expr *expr, + const struct nft_data **data) +{ + return nft_chain_validate_hooks(ctx->chain, (1 << NF_INET_PRE_ROUTING) | + (1 << NF_INET_LOCAL_IN) | + (1 << NF_INET_LOCAL_OUT)); +} + static int nft_socket_dump(struct sk_buff *skb, const struct nft_expr *expr) { @@ -230,6 +246,7 @@ static const struct nft_expr_ops nft_socket_ops = { .size = NFT_EXPR_SIZE(sizeof(struct nft_socket)), .eval = nft_socket_eval, .init = nft_socket_init, + .validate = nft_socket_validate, .dump = nft_socket_dump, .reduce = nft_socket_reduce, };